diff --git a/.gitignore b/.gitignore index 68bc17f9..b14ebdad 100644 --- a/.gitignore +++ b/.gitignore @@ -158,3 +158,4 @@ cython_debug/ # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ +docs/ diff --git a/docs/.gitignore b/docs/.gitignore deleted file mode 100644 index d6b7ef32..00000000 --- a/docs/.gitignore +++ /dev/null @@ -1,2 +0,0 @@ -* -!.gitignore diff --git a/shimmer/__init__.py b/shimmer/__init__.py index 6e82fe8b..cfc082e9 100644 --- a/shimmer/__init__.py +++ b/shimmer/__init__.py @@ -11,6 +11,7 @@ from shimmer.modules.global_workspace import ( GlobalWorkspace, GlobalWorkspaceBase, + GWPredictions, SchedulerArgs, VariationalGlobalWorkspace, pretrained_global_workspace, @@ -70,6 +71,7 @@ "GlobalWorkspaceBase", "VariationalGlobalWorkspace", "SchedulerArgs", + "GWPredictions", "pretrained_global_workspace", "RepeatedDataset", ] diff --git a/shimmer/dataset.py b/shimmer/dataset.py index 3dff1f89..96bc4ab3 100644 --- a/shimmer/dataset.py +++ b/shimmer/dataset.py @@ -13,7 +13,7 @@ class RepeatedDataset(Dataset): """ Dataset that cycles through its items to have a size of at least min size. If drop_last is True, the size will be exaclty min_size. If drop_last is False, - the min_size <= size < min_size + len(dataset). + the min_size ≤ size < min_size + len(dataset). """ def __init__( @@ -25,7 +25,7 @@ def __init__( """ Args: dataset (SizedDataset): dataset to repeat. The dataset should have a size - (__len__ defined). + (where `__len__` is defined). min_size (int): minimum size of the final dataset drop_last (bool): whether to remove overflow when repeating the dataset. @@ -43,7 +43,7 @@ def __init__( def __len__(self) -> int: """ Size of the dataset. Will be min_size if drop_last is True. - Otherwise, min_size <= size < min_size + len(dataset). + Otherwise, min_size ≤ size < min_size + len(dataset). """ return self.total_size diff --git a/shimmer/modules/__init__.py b/shimmer/modules/__init__.py index 4b639660..3191fbf2 100644 --- a/shimmer/modules/__init__.py +++ b/shimmer/modules/__init__.py @@ -10,6 +10,7 @@ from shimmer.modules.global_workspace import ( GlobalWorkspace, GlobalWorkspaceBase, + GWPredictions, SchedulerArgs, VariationalGlobalWorkspace, pretrained_global_workspace, @@ -65,5 +66,6 @@ "GlobalWorkspaceBase", "VariationalGlobalWorkspace", "SchedulerArgs", + "GWPredictions", "pretrained_global_workspace", ] diff --git a/shimmer/modules/contrastive_loss.py b/shimmer/modules/contrastive_loss.py index 4b814b7c..eaeeb2e1 100644 --- a/shimmer/modules/contrastive_loss.py +++ b/shimmer/modules/contrastive_loss.py @@ -10,6 +10,7 @@ ContrastiveLossType = Callable[[torch.Tensor, torch.Tensor], LossOutput] """Contrastive loss function type. + A function taking the prediction and targets and returning a LossOutput. """ @@ -17,6 +18,7 @@ [torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], LossOutput ] """Contrastive loss function type for variational GlobalWorkspace. + A function taking the prediction mean, prediction std, target mean and target std and returns a LossOutput. 
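+
+Example (an illustrative sketch; such a function could be passed as the
+`var_contrastive_loss` argument of `VariationalGlobalWorkspace`; assumes
+`LossOutput` takes the loss tensor as its first argument):
+```python
+def naive_var_contrastive_loss(
+    x_mean: torch.Tensor,
+    x_std: torch.Tensor,
+    y_mean: torch.Tensor,
+    y_std: torch.Tensor,
+) -> LossOutput:
+    # Illustrative only: ignore the uncertainty terms and reuse the mean-based loss.
+    return LossOutput(contrastive_loss(x_mean, y_mean, torch.tensor(1.0)))
+```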
""" @@ -31,10 +33,10 @@ def info_nce( """InfoNCE loss Args: - x: prediction - y: target - logit_scale: logit scale - reduction: reduction to apply + x (`torch.Tensor`): prediction + y (`torch.Tensor`): target + logit_scale (`torch.Tensor`): logit scale + reduction (`Literal["mean", "sum", "none"]`): reduction to apply Returns: the InfoNCE loss """ @@ -54,10 +56,10 @@ def contrastive_loss( """CLIP-like contrastive loss Args: - x: prediction - y: target - logit_scale: logit scale - reduction: reduction to apply + x (`torch.Tensor`): prediction + y (`torch.Tensor`): target + logit_scale (`torch.Tensor`): logit scale + reduction (`Literal["mean", "sum", "none"]`): reduction to apply Returns: the contrastive loss """ @@ -82,12 +84,12 @@ def contrastive_loss_with_uncertainty( This is used in Variational Global Workspaces. Args: - x: prediction - x_log_uncertainty: logvar of the prediction - y: target - y_log_uncertainty: logvar of the target - logit_scale: logit scale - reduction: reduction to apply + x (`torch.Tensor`): prediction + x_log_uncertainty (`torch.Tensor`): logvar of the prediction + y (`torch.Tensor`): target + y_log_uncertainty (`torch.Tensor`): logvar of the target + logit_scale (`torch.Tensor`): logit scale + reduction (`Literal["mean", "sum", "none"]`): reduction to apply Returns: the contrastive loss with uncertainty. """ @@ -104,7 +106,7 @@ def contrastive_loss_with_uncertainty( class ContrastiveLoss(torch.nn.Module): - """CLIP-like ContrastiveLoss torch module""" + """CLIP-like ContrastiveLoss torch module.""" def __init__( self, @@ -115,10 +117,11 @@ def __init__( """Initializes a contrastive loss. Args: - logit_scale: logit_scale tensor. - reduction: reduction to apply to the loss. Defaults to "mean" - learn_logit_scale: whether to learn the logit_scale parameter. Defaults to - False. + logit_scale (`torch.Tensor`): logit_scale tensor. + reduction (`Literal["mean", "sum", "none"]`): reduction to apply to the + loss. Defaults to `"mean"`. + learn_logit_scale (`torch.Tensor`): whether to learn the `logit_scale` + parameter. Defaults to `False`. """ super().__init__() @@ -130,10 +133,11 @@ def __init__( self.reduction: Literal["mean", "sum", "none"] = reduction def forward(self, x: torch.Tensor, y: torch.Tensor) -> LossOutput: - """Computes the loss + """Computes the loss. + Args: - x: prediction - y: target + x (`torch.Tensor`): prediction + y (`torch.Tensor`): target Returns: LossOutput of the loss. Contains a `logit_scale` metric. @@ -146,6 +150,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> LossOutput: class ContrastiveLossWithUncertainty(torch.nn.Module): """CLIP-like contrastive loss with uncertainty module. + This is used in Variational Global Workspaces. """ @@ -159,10 +164,11 @@ def __init__( ContrastiveLoss used for VariationalGlobalWorkspace Args: - logit_scale: logit_scale tensor. - reduction: reduction to apply to the loss. Defaults to "mean" - learn_logit_scale: whether to learn the logit_scale parameter. Defaults to - False. + logit_scale (`torch.Tensor`): logit_scale tensor. + reduction (`Literal["mean", "sum", "none"]`): reduction to apply to + the loss. Defaults to `"mean"`. + learn_logit_scale (`bool`): whether to learn the logit_scale parameter. + Defaults to `False`. """ super().__init__() @@ -188,8 +194,8 @@ def forward( y_log_uncertainty: target logvar Returns: - LossOutput of the loss. Contains a `logit_scale` metric and a - `no_uncertainty` metric with the classic contrastive loss computed without + LossOutput of the loss. 
Contains a `"logit_scale"` metric and a
+            `"no_uncertainty"` metric with the classic contrastive loss computed without
             the logvar information.
         """
         return LossOutput(
diff --git a/shimmer/modules/domain.py b/shimmer/modules/domain.py
index 66a1d7bb..8b7869fb 100644
--- a/shimmer/modules/domain.py
+++ b/shimmer/modules/domain.py
@@ -33,7 +33,9 @@ def all(self) -> dict[str, torch.Tensor]:
 class DomainModule(pl.LightningModule):
     """
     Base class for a DomainModule that defines domain specific modules of the GW.
-    We do not use ABC here because some modules could be without encore or decoder.
+
+    Note: We do not use ABC here because some modules could
+    be without encoder or decoder.
     """
 
     def __init__(
@@ -44,7 +46,7 @@ def __init__(
         Initializes a DomainModule.
 
         Args:
-            latent_dim: latent dimension of the unimodal module
+            latent_dim (`int`): latent dimension of the unimodal module
         """
         super().__init__()
 
@@ -54,23 +56,90 @@ def __init__(
     def encode(self, x: Any) -> torch.Tensor:
         """
         Encode the domain data into a unimodal representation.
+
         Args:
-            x: data of the domain.
+            x (`Any`): data of the domain.
         Returns:
-            a unimodal representation.
+            `torch.Tensor`: a unimodal representation.
         """
         raise NotImplementedError
 
     def decode(self, z: torch.Tensor) -> Any:
         """
         Decode data from unimodal representation back to the domain data.
+
         Args:
-            z: unimodal representation of the domain.
+            z (`torch.Tensor`): unimodal representation of the domain.
         Returns:
-            the original domain data.
+            `Any`: the original domain data.
         """
         raise NotImplementedError
 
+    def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
+        """Generic loss computation for the modality.
+
+        Args:
+            pred (`torch.Tensor`): prediction of the model
+            target (`torch.Tensor`): target tensor
+        Results:
+            `LossOutput`: LossOutput with training loss and additional metrics.
+        """
+        raise NotImplementedError
+
+    def compute_dcy_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
+        """
+        Computes the loss for a demi-cycle. Override if the demi-cycle loss is
+        different than the generic loss.
+
+        Args:
+            pred (`torch.Tensor`): prediction of the model
+            target (`torch.Tensor`): target tensor
+        Results:
+            `LossOutput`: LossOutput with training loss and additional metrics.
+        """
+        return self.compute_loss(pred, target)
+
+    def compute_cy_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
+        """
+        Computes the loss for a cycle. Override if the cycle loss is
+        different than the generic loss.
+
+        Args:
+            pred (`torch.Tensor`): prediction of the model
+            target (`torch.Tensor`): target tensor
+        Results:
+            `LossOutput`: LossOutput with training loss and additional metrics.
+        """
+        return self.compute_loss(pred, target)
+
+    def compute_tr_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput:
+        """
+        Computes the loss for a translation. Override if the translation loss is
+        different than the generic loss.
+
+        Args:
+            pred (`torch.Tensor`): prediction of the model
+            target (`torch.Tensor`): target tensor
+        Results:
+            `LossOutput`: LossOutput with training loss and additional metrics.
+        """
+        return self.compute_loss(pred, target)
+
+    def compute_broadcast_loss(
+        self, pred: torch.Tensor, target: torch.Tensor
+    ) -> LossOutput:
+        """
+        Computes the loss for a broadcast (fusion). Override if the broadcast loss is
+        different than the generic loss.
+ + Args: + pred (`torch.Tensor`): prediction of the model + target (`torch.Tensor`): target tensor + Results: + `LossOutput`: LossOuput with training loss and additional metrics. + """ + return self.compute_loss(pred, target) + def on_before_gw_encode_dcy(self, z: torch.Tensor) -> torch.Tensor: """Some additional computation to do before encoding the unimodal latent representation to the GW when doing a demi-cycle loss. @@ -78,10 +147,10 @@ def on_before_gw_encode_dcy(self, z: torch.Tensor) -> torch.Tensor: If not defined, will return the input (identity function). Args: - z: latent representation + z (`torch.Tensor`): latent representation Returns: - The updated latent representation + `torch.Tensor`: The updated latent representation """ return z @@ -92,10 +161,10 @@ def on_before_gw_encode_cont(self, z: torch.Tensor) -> torch.Tensor: If not defined, will return the input (identity function). Args: - z: latent representation + z (`torch.Tensor`): latent representation Returns: - The updated latent representation + `torch.Tensor`: The updated latent representation """ return z @@ -106,10 +175,10 @@ def on_before_gw_encode_tr(self, z: torch.Tensor) -> torch.Tensor: If not defined, will return the input (identity function). Args: - z: latent representation + z (`torch.Tensor`): latent representation Returns: - The updated latent representation + `torch.Tensor`: the updated latent representation """ return z @@ -120,10 +189,10 @@ def on_before_gw_encode_cy(self, z: torch.Tensor) -> torch.Tensor: If not defined, will return the input (identity function). Args: - z: latent representation + z (`torch.Tensor`): latent representation Returns: - The updated latent representation + `torch.Tensor`: the updated latent representation """ return z @@ -134,70 +203,9 @@ def on_before_gw_encode_broadcast(self, z: torch.Tensor) -> torch.Tensor: If not defined, will return the input (identity function). Args: - z: latent representation + z (`torch.Tensor`): latent representation Returns: - The updated latent representation + `torch.Tensor`: the updated latent representation """ return z - - def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput: - """Generic loss computation the modality. - - Args: - pred: prediction of the model - target: target tensor - Results: - LossOuput with training loss and additional metrics. - """ - raise NotImplementedError - - def compute_dcy_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput: - """ - Computes the loss for a demi-cycle. Override if the demi-cycle loss is - different that the generic loss. - Args: - pred: prediction of the model - target: target tensor - Results: - LossOuput with training loss and additional metrics. - """ - return self.compute_loss(pred, target) - - def compute_cy_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput: - """ - Computes the loss for a cycle. Override if the cycle loss is - different that the generic loss. - Args: - pred: prediction of the model - target: target tensor - Results: - LossOuput with training loss and additional metrics. - """ - return self.compute_loss(pred, target) - - def compute_tr_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput: - """ - Computes the loss for a translation. Override if the translation loss is - different that the generic loss. - Args: - pred: prediction of the model - target: target tensor - Results: - LossOuput with training loss and additional metrics. 
- """ - return self.compute_loss(pred, target) - - def compute_broadcast_loss( - self, pred: torch.Tensor, target: torch.Tensor - ) -> LossOutput: - """ - Computes the loss for a broadcast (fusion). Override if the broadcast loss is - different that the generic loss. - Args: - pred: prediction of the model - target: target tensor - Results: - LossOuput with training loss and additional metrics. - """ - return self.compute_loss(pred, target) diff --git a/shimmer/modules/global_workspace.py b/shimmer/modules/global_workspace.py index ca84f2e8..3035c859 100644 --- a/shimmer/modules/global_workspace.py +++ b/shimmer/modules/global_workspace.py @@ -1,6 +1,6 @@ from collections.abc import Iterable, Mapping from pathlib import Path -from typing import Any, TypedDict, cast +from typing import Any, Literal, TypedDict, cast import torch from lightning.pytorch import LightningModule @@ -26,8 +26,13 @@ GWLosses, GWLossesBase, GWLossesFusion, - LatentsT, + LatentsDomainGroupsDT, + LatentsDomainGroupsT, + LatentsDomainGroupT, LossCoefs, + RawDomainGroupsDT, + RawDomainGroupsT, + RawDomainGroupT, VariationalGWLosses, VariationalLossCoefs, ) @@ -44,20 +49,38 @@ class SchedulerArgs(TypedDict, total=False): class GWPredictions(TypedDict): - """TypedDict of the output given when calling GlobalWorkspaceBase.predict""" + """TypedDict of the output given when calling `GlobalWorkspaceBase.predict`""" demi_cycles: dict[str, torch.Tensor] - """demi_cycle predictions of the model for each domain. Only computed on domain - groups with only one domain.""" + """Demi-cycle predictions of the model for each domain. Only computed on domain + groups with only one domain. + """ cycles: dict[tuple[str, str], torch.Tensor] - """cycle predictions of the model from one domain through another one. Only computed - on domain groups""" + """Cycle predictions of the model from one domain through another one. + Only computed on domain groups with more than one domain. + The keys are tuple with start domain and intermediary domain. + """ + translations: dict[tuple[str, str], torch.Tensor] + """Translation predictions of the model from one domain through another one. + + Only computed on domain groups with more than one domain. + The keys are tuples with start domain and target domain. + """ + states: dict[str, torch.Tensor] + """GW state representation from domain groups with only one domain. + The key represent the domain's name. + """ class GlobalWorkspaceBase(LightningModule): + """Global Workspace Lightning Module. + + This is the base class to build the Global Workspace. + """ + def __init__( self, gw_mod: GWModuleBase, @@ -67,6 +90,19 @@ def __init__( optim_weight_decay: float = 0.0, scheduler_args: SchedulerArgs | None = None, ) -> None: + """Initializes a GW + + Args: + gw_mod (`GWModuleBase`): the GWModule + domain_mods (`Mapping[str, DomainModule]`): mapping of the domains + connected to the GW. Keys are domain names, values are the + `DomainModule`. + loss_mod (`GWLossesBase`): module to compute the GW losses. + optim_lr (`float`): learning rate + optim_weight_decay (`float`): weight decay + scheduler_args (`SchedulerArgs`): `SchedulerArgs` instance to define + scheduler parameters. 
+ """ super().__init__() self.save_hyperparameters( ignore=[ @@ -80,8 +116,11 @@ def __init__( ) self.gw_mod = gw_mod + """ a `GWModuleBase` implementation.""" self.domain_mods = domain_mods + """Mapping of `DomainModule`s.""" self.loss_mod = loss_mod + """The module that computes losses of the GW""" self.optim_lr = optim_lr self.optim_weight_decay = optim_weight_decay @@ -90,29 +129,73 @@ def __init__( self.scheduler_args.update(scheduler_args) @property - def workspace_dim(self): + def workspace_dim(self) -> int: + """Dimension of the GW.""" return self.gw_mod.workspace_dim - def encode(self, x: Mapping[str, torch.Tensor]) -> torch.Tensor: + def encode(self, x: LatentsDomainGroupT) -> torch.Tensor: + """Encode latent representations into the GW representation. + + This directly calls `GWModuleBase.encode` and is a convenient proxy to + ```python + self.gw_mod.encode(x) + ``` + + Args: + x (`LatentsDomainGroupT`): the input domain representations. + + Returns: + `torch.Tensor`: the GW representations. + """ return self.gw_mod.encode(x) def decode( self, z: torch.Tensor, domains: Iterable[str] | None = None ) -> dict[str, torch.Tensor]: + """Decode the GW representation into given `domains`. + + This directly calls `GWModuleBase.decode` and is a convenient proxy to + ```python + self.gw_mod.decode(x) + ``` + + Args: + z (`torch.Tensor`): the GW representation. + domains (`Iterable[str]`): iterable of domains to decode. + + Returns: + `dict[str, torch.Tensor]`: the decoded unimodal representations. + """ return self.gw_mod.decode(z, domains) - def forward(self, latent_domains: LatentsT) -> GWPredictions: - outputs = GWPredictions( - **{ - "demi_cycles": self.batch_demi_cycles(latent_domains), - "cycles": self.batch_cycles(latent_domains), - "translations": self.batch_translations(latent_domains), - "states": self.batch_gw_states(latent_domains), - } + def forward(self, latent_domains: LatentsDomainGroupsT) -> GWPredictions: + """Computes demi-cycles, cycles, and translations. + + Args: + latent_domains (`LatentsT`): Groups of domains for the computation. + + Returns: + `GWPredictions`: the predictions on the batch. + """ + + return GWPredictions( + demi_cycles=self.batch_demi_cycles(latent_domains), + cycles=self.batch_cycles(latent_domains), + translations=self.batch_translations(latent_domains), + states=self.batch_gw_states(latent_domains), ) - return outputs - def batch_gw_states(self, latent_domains: LatentsT) -> dict[str, torch.Tensor]: + def batch_gw_states( + self, latent_domains: LatentsDomainGroupsT + ) -> dict[str, torch.Tensor]: + """Comptues GW states of a batch of groups of domains. + + Args: + latent_domains (`LatentsT`): the batch of groups of domains + + Returns: + `dict[str, torch.Tensor]`: states for each domain. + """ predictions: dict[str, torch.Tensor] = {} for domains, latents in latent_domains.items(): if len(domains) > 1: @@ -122,7 +205,17 @@ def batch_gw_states(self, latent_domains: LatentsT) -> dict[str, torch.Tensor]: predictions[domain_name] = z return predictions - def batch_demi_cycles(self, latent_domains: LatentsT) -> dict[str, torch.Tensor]: + def batch_demi_cycles( + self, latent_domains: LatentsDomainGroupsT + ) -> dict[str, torch.Tensor]: + """Computes demi-cycles of a batch of groups of domains. + + Args: + latent_domains (`LatentsT`): the batch of groups of domains + + Returns: + `dict[str, torch.Tensor]`: demi-cycles predictions for each domain. 
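+
+        Example (an illustrative sketch; assumes `gw` is a `GlobalWorkspaceBase`
+        instance and `latent_domains` a batch of groups of unimodal latents):
+        ```python
+        demi_cycles = gw.batch_demi_cycles(latent_domains)
+        # keys are domain names, values the reconstructed unimodal latents
+        ```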
+        """
         predictions: dict[str, torch.Tensor] = {}
         for domains, latents in latent_domains.items():
             if len(domains) > 1:
                 continue
             for domain_name, latent in latents.items():
@@ -133,8 +226,17 @@ def batch_demi_cycles(self, latent_domains: LatentsT) -> dict[str, torch.Tensor]
         return predictions
 
     def batch_cycles(
-        self, latent_domains: LatentsT
+        self, latent_domains: LatentsDomainGroupsT
     ) -> dict[tuple[str, str], torch.Tensor]:
+        """Computes cycles of a batch of groups of domains.
+
+        Args:
+            latent_domains (`LatentsDomainGroupsT`): the batch of groups of domains
+
+        Returns:
+            `dict[tuple[str, str], torch.Tensor]`: cycles predictions for each
+            couple of (start domain, intermediary domain).
+        """
         predictions: dict[tuple[str, str], torch.Tensor] = {}
         for domains_source, latents_source in latent_domains.items():
             if len(domains_source) > 1:
@@ -149,8 +251,17 @@ def batch_cycles(
         return predictions
 
     def batch_translations(
-        self, latent_domains: LatentsT
+        self, latent_domains: LatentsDomainGroupsT
     ) -> dict[tuple[str, str], torch.Tensor]:
+        """Computes translations of a batch of groups of domains.
+
+        Args:
+            latent_domains (`LatentsDomainGroupsT`): the batch of groups of domains
+
+        Returns:
+            `dict[tuple[str, str], torch.Tensor]`: translation predictions for each
+            couple of (start domain, target domain).
+        """
         predictions: dict[tuple[str, str], torch.Tensor] = {}
         for domains, latents in latent_domains.items():
             if len(domains) < 2:
@@ -167,12 +278,37 @@ def batch_translations(
         return predictions
 
     def encode_domain(self, domain: Any, name: str) -> torch.Tensor:
+        """Encodes a domain from the domain data into the unimodal representation.
+
+        This is a convenient proxy for the `DomainModule.encode` method and is
+        equivalent to:
+        ```python
+        self.domain_mods[name].encode(domain)
+        ```
+
+        Args:
+            domain (`Any`): the domain data
+            name (`str`): domain name to encode
+
+        Returns:
+            `torch.Tensor`: the domain's unimodal representation.
+        """
         return self.domain_mods[name].encode(domain)
 
     def encode_domains(
         self,
-        batch: Mapping[frozenset[str], Mapping[str, Any]],
-    ) -> dict[frozenset[str], dict[str, torch.Tensor]]:
+        batch: RawDomainGroupsT,
+    ) -> LatentsDomainGroupsDT:
+        """Encodes all domains in the batch.
+
+        Args:
+            batch (`RawDomainGroupsT`): the batch of
+                domain groups with raw unimodal data to encode into groups of latent
+                representations.
+
+        Returns:
+            `LatentsDomainGroupsDT`: the domains' unimodal representations.
+        """
         return {
             domains: {
                 name: self.domain_mods[name].encode(domain)
@@ -182,12 +318,37 @@ def decode_domain(self, domain: torch.Tensor, name: str) -> Any:
+        """Decodes a domain from the unimodal representation into the domain data.
+
+        This is a convenient proxy for the `DomainModule.decode` method and is
+        equivalent to:
+        ```python
+        self.domain_mods[name].decode(domain)
+        ```
+
+        Args:
+            domain (`torch.Tensor`): the domain's unimodal representation
+            name (`str`): domain name to decode
+
+        Returns:
+            `Any`: the domain's raw data.
+        """
         return self.domain_mods[name].decode(domain)
 
     def decode_domains(
         self,
-        latents_domain: LatentsT,
-    ) -> dict[frozenset[str], dict[str, Any]]:
+        latents_domain: LatentsDomainGroupsT,
+    ) -> RawDomainGroupsDT:
+        """Decodes all domains in the batch.
+
+        Args:
+            latents_domain (`LatentsDomainGroupsT`): the batch of
+                domain groups with unimodal latent representation to decode into
+                groups of raw data.
+
+        Returns:
+            `RawDomainGroupsDT`: the domains' raw data.
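+
+        Example (an illustrative sketch; assumes `gw` is a `GlobalWorkspaceBase`
+        instance and `batch` a batch of raw domain groups):
+        ```python
+        latents = gw.encode_domains(batch)
+        raw_data = gw.decode_domains(latents)
+        ```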
+ """ return { domains: { name: self.domain_mods[name].decode(domain) @@ -198,8 +359,16 @@ def decode_domains( def _get_batch_size( self, - domain_latents: LatentsT, + domain_latents: LatentsDomainGroupsT, ) -> int: + """Get the batch size of the batch. + + Args: + domain_latents (`LatentsDomainGroupsT`): the batch of groups. + + Returns: + int: the batch size. + """ for data in domain_latents.values(): for tensor in data.values(): return tensor.size(0) @@ -207,9 +376,19 @@ def _get_batch_size( def generic_step( self, - batch: Mapping[frozenset[str], Mapping[str, Any]], - mode: str, + batch: RawDomainGroupsT, + mode: Literal["train", "val", "test", "val/ood", "test/ood"], ) -> torch.Tensor: + """The generic step used in `training_step`, `validation_step` and + `test_step`. + + Args: + batch (`RawDomainGroupsT`): the batch of groups of raw unimodal data. + mode (`Literal["train", "val", "test", "val/ood", "test/ood"]`): + + Returns: + `torch.Tensor`: the loss to train on. + """ domain_latents = self.encode_domains(batch) batch_size = self._get_batch_size(domain_latents) @@ -226,8 +405,10 @@ def generic_step( return loss_output.loss def validation_step( - self, data: Mapping[str, Any], _, dataloader_idx: int = 0 + self, data: RawDomainGroupT, _, dataloader_idx: int = 0 ) -> torch.Tensor: + """Validation step used by lightning""" + batch = {frozenset(data.keys()): data} for domain in data.keys(): batch[frozenset([domain])] = {domain: data[domain]} @@ -238,6 +419,8 @@ def validation_step( def test_step( self, data: Mapping[str, Any], _, dataloader_idx: int = 0 ) -> torch.Tensor: + """Test step used by lightning""" + batch = {frozenset(data.keys()): data} for domain in data.keys(): batch[frozenset([domain])] = {domain: data[domain]} @@ -248,9 +431,13 @@ def test_step( def training_step( self, batch: Mapping[frozenset[str], Mapping[str, Any]], _ ) -> torch.Tensor: + """Training step used by lightning""" + return self.generic_step(batch, mode="train") def predict_step(self, data: Mapping[str, Any], _) -> GWPredictions: # type: ignore + """Predict step used by lightning""" + batch = {frozenset(data.keys()): data} for domain in data.keys(): batch[frozenset([domain])] = {domain: data[domain]} @@ -259,6 +446,12 @@ def predict_step(self, data: Mapping[str, Any], _) -> GWPredictions: # type: ig return self.forward(domain_latents) def configure_optimizers(self) -> OptimizerLRSchedulerConfig: + """Configure models optimizers. + + Here we use `AdamW` for the optimizer and `OneCycleLR` for the learning-rate + scheduler. + """ + optimizer = torch.optim.AdamW( self.parameters(), lr=self.optim_lr, @@ -279,6 +472,19 @@ def configure_optimizers(self) -> OptimizerLRSchedulerConfig: def freeze_domain_modules( domain_mods: Mapping[str, DomainModule], ) -> dict[str, DomainModule]: + """Freezes weights and set to eval mode the domain modules. + + > [!NOTE] + > The output is casted as `dict[str, DomainModule]` type for better auto-completion, + > but is actually a torch `ModuleDict`. + + Args: + domain_mods (`Mapping[str, DomainModule]`): mapping of domain modules to freeze + + Returns: + `ModuleDict`: frozen modules. + """ + for mod in domain_mods.values(): mod.freeze() # Cast for better auto-completion at the expense of ModuleDict @@ -286,6 +492,12 @@ def freeze_domain_modules( class GlobalWorkspace(GlobalWorkspaceBase): + """A simple 2-domains max flavor of GlobalWorkspaceBase. + + This is used to simplify a Global Workspace instanciation and only overrides the + `__init__` method. 
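+
+    Example (an illustrative sketch; `domain_mods`, `gw_interfaces` and
+    `loss_coefs` are assumed to be defined by the user):
+    ```python
+    global_workspace = GlobalWorkspace(
+        domain_mods=domain_mods,
+        gw_interfaces=gw_interfaces,
+        workspace_dim=12,
+        loss_coefs=loss_coefs,
+    )
+    ```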
+ """ + def __init__( self, domain_mods: Mapping[str, DomainModule], @@ -298,6 +510,26 @@ def __init__( learn_logit_scale: bool = False, contrastive_loss: ContrastiveLossType | None = None, ) -> None: + """Initializes a Global Workspace + + Args: + domain_mods (`Mapping[str, DomainModule]`): mapping of the domains + connected to the GW. Keys are domain names, values are the + `DomainModule`. + gw_interfaces (`Mapping[str, GWInterfaceBase]`): mapping for each domain + name to a `GWInterfaceBase` class which role is to encode/decode + unimodal latent representations into a GW representation (pre fusion). + workspace_dim (`int`): dimension of the GW. + loss_coefs (`LossCoefs`): loss coefficients + optim_lr (`float`): learning rate + optim_weight_decay (`float`): weight decay + scheduler_args (`SchedulerArgs | None`): optimization scheduler's arguments + learn_logit_scale (`bool`): whether to learn the contrastive learning + contrastive loss when using the default contrastive loss. + contrastive_loss (`ContrastiveLossType | None`): a contrastive loss + function used for alignment. `learn_logit_scale` will not affect custom + contrastive losses. + """ gw_mod = GWModule(gw_interfaces, workspace_dim) domain_mods = freeze_domain_modules(domain_mods) if contrastive_loss is None: @@ -322,6 +554,12 @@ def __init__( class VariationalGlobalWorkspace(GlobalWorkspaceBase): + """A simple 2-domains max variational flavor of GlobalWorkspaceBase. + + This is used to simplify a Global Workspace instanciation and only overrides the + `__init__` method. + """ + def __init__( self, domain_mods: Mapping[str, DomainModule], @@ -336,6 +574,31 @@ def __init__( contrastive_loss: ContrastiveLossType | None = None, var_contrastive_loss: VarContrastiveLossType | None = None, ) -> None: + """Initializes a Global Workspace + + Args: + domain_mods (`Mapping[str, DomainModule]`): mapping of the domains + connected to the GW. Keys are domain names, values are the + `DomainModule`. + gw_interfaces (`Mapping[str, GWInterfaceBase]`): mapping for each domain + name to a `GWInterfaceBase` class which role is to encode/decode + unimodal latent representations into a GW representation (pre fusion). + workspace_dim (`int`): dimension of the GW. + loss_coefs (`LossCoefs`): loss coefficients + use_var_contrastive_loss (`bool`): whether to use the variational + contrastive loss which uses means and log variance for computations. + optim_lr (`float`): learning rate + optim_weight_decay (`float`): weight decay + scheduler_args (`SchedulerArgs | None`): optimization scheduler's arguments + learn_logit_scale (`bool`): whether to learn the contrastive learning + contrastive loss when using the default contrastive loss. + contrastive_loss (`ContrastiveLossType | None`): a contrastive loss + function used for alignment. `learn_logit_scale` will not affect custom + contrastive losses. + var_contrastive_loss (`VarContrastiveLossType | None`): a variational + contrastive loss. Only used if `use_var_contrastive_loss` is set to + `True`. + """ gw_mod = VariationalGWModule(gw_interfaces, workspace_dim) domain_mods = freeze_domain_modules(domain_mods) @@ -373,6 +636,12 @@ def __init__( class GlobalWorkspaceFusion(GlobalWorkspaceBase): + """The 2-domain fusion (with broadcast loss) flavor of GlobalWorkspaceBase. + + This is used to simplify a Global Workspace instanciation and only overrides the + `__init__` method. 
+ """ + def __init__( self, domain_mods: Mapping[str, DomainModule], @@ -384,6 +653,25 @@ def __init__( learn_logit_scale: bool = False, contrastive_loss: ContrastiveLossType | None = None, ) -> None: + """Initializes a Global Workspace + + Args: + domain_mods (`Mapping[str, DomainModule]`): mapping of the domains + connected to the GW. Keys are domain names, values are the + `DomainModule`. + gw_interfaces (`Mapping[str, GWInterfaceBase]`): mapping for each domain + name to a `GWInterfaceBase` class which role is to encode/decode + unimodal latent representations into a GW representation (pre fusion). + workspace_dim (`int`): dimension of the GW. + optim_lr (`float`): learning rate + optim_weight_decay (`float`): weight decay + scheduler_args (`SchedulerArgs | None`): optimization scheduler's arguments + learn_logit_scale (`bool`): whether to learn the contrastive learning + contrastive loss when using the default contrastive loss. + contrastive_loss (`ContrastiveLossType | None`): a contrastive loss + function used for alignment. `learn_logit_scale` will not affect custom + contrastive losses. + """ gw_mod = GWModuleFusion(gw_interfaces, workspace_dim) domain_mods = freeze_domain_modules(domain_mods) @@ -416,6 +704,31 @@ def pretrained_global_workspace( contrastive_fn: ContrastiveLossType, **kwargs, ) -> GlobalWorkspace: + """ + Load a `GlobalWorkspace` flavor of `GlobalWorkspaceBase` from a checkpoint. + + Args: + checkpoint_path (`str | Path`): path to checkpoint + domain_mods (`Mapping[str, DomainModule]`): mapping of the domains + connected to the GW. Keys are domain names, values are the + `DomainModule`. + gw_interfaces (`Mapping[str, GWInterfaceBase]`): mapping for each domain + name to a `GWInterfaceBase` class which role is to encode/decode + unimodal latent representations into a GW representation (pre fusion). + workspace_dim (`int`): dimension of the GW. + loss_coefs (`LossCoefs`): loss coefficients + contrastive_loss (`ContrastiveLossType`): a contrastive loss + function used for alignment. `learn_logit_scale` will not affect custom + contrastive losses. + **kwargs: additional arguments to pass to + `GlobalWorkspace.load_from_checkpoint`. + + Returns: + `GlobalWorkspace`: the pretrained `GlobalWorkspace`. + + Raises: + `TypeError`: if loaded type is not `GlobalWorkspace`. + """ gw_mod = GWModule(gw_interfaces, workspace_dim) domain_mods = freeze_domain_modules(domain_mods) loss_mod = GWLosses( diff --git a/shimmer/modules/gw_module.py b/shimmer/modules/gw_module.py index 6fcd8e82..9cab1dfa 100644 --- a/shimmer/modules/gw_module.py +++ b/shimmer/modules/gw_module.py @@ -6,17 +6,30 @@ from torch import nn from shimmer.modules.domain import DomainModule +from shimmer.modules.losses import LatentsDomainGroupT from shimmer.modules.vae import reparameterize -def get_n_layers(n_layers: int, hidden_dim: int): - layers = [] +def get_n_layers(n_layers: int, hidden_dim: int) -> list[nn.Module]: + """Makes a list of `n_layers` `nn.Linear` layers with `nn.ReLU`. + + Args: + n_layers (`int`): number of layers + hidden_dim (`int`): size of the hidden dimension + + Returns: + `list[nn.Module]`: list of linear and relu layers. + """ + + layers: list[nn.Module] = [] for _ in range(n_layers): layers.extend([nn.Linear(hidden_dim, hidden_dim), nn.ReLU()]) return layers class GWDecoder(nn.Sequential): + """A Decoder network used in GWInterfaces.""" + def __init__( self, in_dim: int, @@ -24,11 +37,28 @@ def __init__( out_dim: int, n_layers: int, ): + """Initializes the decoder. 
+ + Args: + in_dim (`int`): input dimension + hidden_dim (`int`): hidden dimension + out_dim (`int`): output dimension + n_layers (`int`): number of hidden layers. The total number of layers + will be `n_layers` + 2 (one before, one after). + """ + self.in_dim = in_dim + """input dimension""" + self.hidden_dim = hidden_dim + """hidden dimension""" + self.out_dim = out_dim + """output dimension""" self.n_layers = n_layers + """number of hidden layers. The total number of layers + will be `n_layers` + 2 (one before, one after).""" super().__init__( nn.Linear(self.in_dim, self.hidden_dim), @@ -39,6 +69,11 @@ def __init__( class GWEncoder(GWDecoder): + """An Encoder network used in GWInterfaces. + + This is similar to the decoder, but adds a tanh non-linearity at the end. + """ + def __init__( self, in_dim: int, @@ -46,6 +81,15 @@ def __init__( out_dim: int, n_layers: int, ): + """Initializes the encoder. + + Args: + in_dim (`int`): input dimension + hidden_dim (`int`): hidden dimension + out_dim (`int`): output dimension + n_layers (`int`): number of hidden layers. The total number of layers + will be `n_layers` + 2 (one before, one after). + """ super().__init__(in_dim, hidden_dim, out_dim, n_layers) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -53,6 +97,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class VariationalGWEncoder(nn.Module): + """A Variational flavor of encoder network used in GWInterfaces.""" + def __init__( self, in_dim: int, @@ -60,12 +106,29 @@ def __init__( out_dim: int, n_layers: int, ): + """Initializes the encoder. + + Args: + in_dim (`int`): input dimension + hidden_dim (`int`): hidden dimension + out_dim (`int`): output dimension + n_layers (`int`): number of hidden layers. The total number of layers + will be `n_layers` + 2 (one before, one after). + """ super().__init__() self.in_dim = in_dim + """input dimension""" + self.hidden_dim = hidden_dim + """hidden dimension""" + self.out_dim = out_dim + """output dimension""" + self.n_layers = n_layers + """number of hidden layers. The total number of layers + will be `n_layers` + 2 (one before, one after).""" self.layers = nn.Sequential( nn.Linear(self.in_dim, self.hidden_dim), @@ -74,6 +137,7 @@ def __init__( nn.Linear(self.hidden_dim, self.out_dim), nn.Tanh(), ) + self.uncertainty_level = nn.Parameter(torch.full((self.out_dim,), 3.0)) def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: @@ -81,28 +145,90 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: class GWInterfaceBase(nn.Module, ABC): + """Base class for GWInterfaces. + + Interfaces encode and decode unimodal representation to the domain's GW + representation (pre-fusion). + + This is an abstract class and should be implemented. + For an implemented interface, see `GWInterface`. + """ + def __init__(self, domain_module: DomainModule, workspace_dim: int) -> None: + """ + Initialized the interface. + + Args: + domain_module (`DomainModule`): Domain module to link. + workspace_dim (`int`): dimension of the GW. + """ super().__init__() + self.domain_module = domain_module + """Domain module.""" + self.workspace_dim = workspace_dim + """Dimension of the GW.""" @abstractmethod - def encode(self, x: torch.Tensor) -> torch.Tensor: ... + def encode(self, x: torch.Tensor) -> torch.Tensor: + """ + Encode from the unimodal latent representation to the domain's GW + representation (pre-fusion). + + Args: + x (`torch.Tensor`): the domain's unimodal latent representation. 
+ + Returns: + `torch.Tensor`: the domain's pre-fusion GW representation. + """ + ... @abstractmethod - def decode(self, z: torch.Tensor) -> torch.Tensor: ... + def decode(self, z: torch.Tensor) -> torch.Tensor: + """ + Decode from the domain's pre-fusion GW to the unimodal latent representation. + + Args: + z (`torch.Tensor`): the domain's pre-fusion GW representation. + + Returns: + `torch.Tensor`: the domain's unimodal latent representation. + """ + ... class GWModuleBase(nn.Module, ABC): + """Base class for GWModule. + + GWModule handle how to merge representations from the Interfaces and define + some common operations in GW like cycles and translations. + + This is an abstract class and should be implemented. + For an implemented interface, see `GWModule`. + """ + def __init__( self, gw_interfaces: Mapping[str, GWInterfaceBase], workspace_dim: int ) -> None: + """Initializes the GWModule. + + Args: + gw_interfaces (`Mapping[str, GWInterfaceBase]`): mapping for each domain + name to a `GWInterfaceBase` class which role is to encode/decode + unimodal latent representations into a GW representation (pre fusion). + workspace_dim (`int`): dimension of the GW. + """ super().__init__() + # casting for LSP autocompletion self.gw_interfaces = cast( dict[str, GWInterfaceBase], nn.ModuleDict(gw_interfaces) ) + """the GWInterface""" + self.workspace_dim = workspace_dim + """dimension of the GW""" def on_before_gw_encode_dcy( self, x: Mapping[str, torch.Tensor] @@ -110,10 +236,12 @@ def on_before_gw_encode_dcy( """ Callback used before projecting the unimodal representations to the GW representation when computing the demi-cycle loss. Defaults to identity. + Args: - x: mapping of domain name to latent representation. + x (`Mapping[str, torch.Tensor]`): mapping of domain name to + latent representation. Returns: - the same mapping with updated representations + `dict[str, torch.Tensor]`: the same mapping with updated representations """ return { domain: self.gw_interfaces[domain].domain_module.on_before_gw_encode_dcy( @@ -128,10 +256,12 @@ def on_before_gw_encode_cy( """ Callback used before projecting the unimodal representations to the GW representation when computing the cycle loss. Defaults to identity. + Args: - x: mapping of domain name to latent representation. + x (`Mapping[str, torch.Tensor]`): mapping of domain name to + latent representation. Returns: - the same mapping with updated representations + `dict[str, torch.Tensor]`: the same mapping with updated representations """ return { domain: self.gw_interfaces[domain].domain_module.on_before_gw_encode_cy( @@ -146,10 +276,12 @@ def on_before_gw_encode_tr( """ Callback used before projecting the unimodal representations to the GW representation when computing the translation loss. Defaults to identity. + Args: - x: mapping of domain name to latent representation. + x (`Mapping[str, torch.Tensor]`): mapping of domain name to + latent representation. Returns: - the same mapping with updated representations + `dict[str, torch.Tensor]`: the same mapping with updated representations """ return { domain: self.gw_interfaces[domain].domain_module.on_before_gw_encode_tr( @@ -164,10 +296,12 @@ def on_before_gw_encode_cont( """ Callback used before projecting the unimodal representations to the GW representation when computing the contrastive loss. Defaults to identity. + Args: - x: mapping of domain name to latent representation. + x (`Mapping[str, torch.Tensor]`): mapping of domain name to + latent representation. 
Returns: - the same mapping with updated representations + `dict[str, torch.Tensor]`: the same mapping with updated representations """ return { domain: self.gw_interfaces[domain].domain_module.on_before_gw_encode_cont( @@ -177,13 +311,14 @@ def on_before_gw_encode_cont( } @abstractmethod - def encode(self, x: Mapping[str, torch.Tensor]) -> torch.Tensor: - """ - Encode the unimodal representations to the GW representation. + def encode(self, x: LatentsDomainGroupT) -> torch.Tensor: + """Encode latent representations into the GW representation. + Args: - x: mapping of domain name to unimodal representation. + x (`LatentsDomainGroupT`): the input domain representations. + Returns: - GW representation + `torch.Tensor`: the GW representations. """ ... @@ -191,44 +326,52 @@ def encode(self, x: Mapping[str, torch.Tensor]) -> torch.Tensor: def decode( self, z: torch.Tensor, domains: Iterable[str] | None = None ) -> dict[str, torch.Tensor]: - """ - Decode the GW representation to the unimodal representations. + """Decode the GW representation into given `domains`. + Args: - z: GW representation - domains: iterable of domains to decode to. Defaults to all domains. + z (`torch.Tensor`): the GW representation. + domains (`Iterable[str]`): iterable of domains to decode. + Returns: - dict of domain name to decoded unimodal representation. + `dict[str, torch.Tensor]`: the decoded unimodal representations. """ ... @abstractmethod - def translate(self, x: Mapping[str, torch.Tensor], to: str) -> torch.Tensor: + def translate(self, x: LatentsDomainGroupT, to: str) -> torch.Tensor: """ Translate from one domain to another. + Args: - x: mapping of domain name to unimodal representation. - to: domain to translate to. + x (`LatentsDomainGroupT`): mapping of domain name + to unimodal representation. + to (`str`): domain to translate to. Returns: - the unimodal representation of domain given by `to`. + `torch.Tensor`: the unimodal representation of domain given by `to`. """ ... @abstractmethod - def cycle( - self, x: Mapping[str, torch.Tensor], through: str - ) -> dict[str, torch.Tensor]: + def cycle(self, x: LatentsDomainGroupT, through: str) -> dict[str, torch.Tensor]: """ Cycle from one domain through another. + Args: - x: mapping of domain name to unimodal representation. - through: domain to translate to. + x (`LatentsDomainGroupT`): mapping of domain name + to unimodal representation. + through (`str`): intermediate domain of the cycle + Returns: - the unimodal representations cycles through the given domain. + `torch.Tensor`: the unimodal representation of domain given by `to`. """ ... class GWInterface(GWInterfaceBase): + """ + A implementation of `GWInterfaceBase` using `GWEncoder` and `GWDecoder`. + """ + def __init__( self, domain_module: DomainModule, diff --git a/shimmer/modules/losses.py b/shimmer/modules/losses.py index 16b19791..88d835e4 100644 --- a/shimmer/modules/losses.py +++ b/shimmer/modules/losses.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from collections.abc import Mapping -from typing import TypedDict +from typing import Any, TypedDict import torch import torch.nn.functional as F @@ -10,8 +10,46 @@ from shimmer.modules.gw_module import GWModule, GWModuleBase, VariationalGWModule from shimmer.modules.vae import kl_divergence_loss +RawDomainGroupT = Mapping[str, Any] +"""Matched raw unimodal data from multiple domains. +Keys of the mapping are domains names.""" + +RawDomainGroupDT = dict[str, Any] +"""Matched raw unimodal data from multiple domains. 
+Keys of the dict are domains names. + +This is a more specific version of `RawDomainGroupT` used in method's outputs.""" + LatentsDomainGroupT = Mapping[str, torch.Tensor] -LatentsT = Mapping[frozenset[str], LatentsDomainGroupT] +"""Matched unimodal latent representations from multiple domains. +Keys of the mapping are domains names.""" + +LatentsDomainGroupDT = dict[str, torch.Tensor] +"""Matched unimodal latent representations from multiple domains. +Keys of the dict are domains names. + +This is a more specific version of `LatentsDomainGroupT` used in method's outputs.""" + +LatentsDomainGroupsT = Mapping[frozenset[str], LatentsDomainGroupT] +"""Mapping of `LatentsDomainGroupT`. Keys are frozenset of domains matched in the group. +Each group is independent and contains different data (unpaired).""" + +LatentsDomainGroupsDT = dict[frozenset[str], LatentsDomainGroupDT] +"""Mapping of `LatentsDomainGroupDT`. +Keys are frozenset of domains matched in the group. +Each group is independent and contains different data (unpaired). + +This is a more specific version of `LatentsDomainGroupsT` used in method's outputs.""" + +RawDomainGroupsT = Mapping[frozenset[str], RawDomainGroupT] +"""Mapping of `RawDomainGroupT`. Keys are frozenset of domains matched in the group. +Each group is independent and contains different data (unpaired).""" + +RawDomainGroupsDT = dict[frozenset[str], RawDomainGroupDT] +"""Mapping of `RawDomainGroupT`. Keys are frozenset of domains matched in the group. +Each group is independent and contains different data (unpaired). + +This is a more specific version of `RawDomainGroupsT` used in method's outputs.""" class GWLossesBase(torch.nn.Module, ABC): @@ -22,7 +60,7 @@ class GWLossesBase(torch.nn.Module, ABC): """ @abstractmethod - def step(self, domain_latents: LatentsT, mode: str) -> LossOutput: + def step(self, domain_latents: LatentsDomainGroupsT, mode: str) -> LossOutput: """ Computes the losses Args: @@ -36,7 +74,7 @@ def step(self, domain_latents: LatentsT, mode: str) -> LossOutput: def _demi_cycle_loss( gw_mod: GWModuleBase, domain_mods: dict[str, DomainModule], - latent_domains: LatentsT, + latent_domains: LatentsDomainGroupsT, ) -> dict[str, torch.Tensor]: losses: dict[str, torch.Tensor] = {} metrics: dict[str, torch.Tensor] = {} @@ -62,7 +100,7 @@ def _demi_cycle_loss( def _cycle_loss( gw_mod: GWModuleBase, domain_mods: dict[str, DomainModule], - latent_domains: LatentsT, + latent_domains: LatentsDomainGroupsT, ) -> dict[str, torch.Tensor]: losses: dict[str, torch.Tensor] = {} metrics: dict[str, torch.Tensor] = {} @@ -99,7 +137,7 @@ def _cycle_loss( def _translation_loss( gw_mod: GWModuleBase, domain_mods: dict[str, DomainModule], - latent_domains: LatentsT, + latent_domains: LatentsDomainGroupsT, ) -> dict[str, torch.Tensor]: losses: dict[str, torch.Tensor] = {} metrics: dict[str, torch.Tensor] = {} @@ -143,7 +181,7 @@ def _translation_loss( def _contrastive_loss( gw_mod: GWModuleBase, - latent_domains: LatentsT, + latent_domains: LatentsDomainGroupsT, contrastive_fn: ContrastiveLossType, ) -> dict[str, torch.Tensor]: losses: dict[str, torch.Tensor] = {} @@ -179,7 +217,7 @@ def _contrastive_loss( def _contrastive_loss_with_uncertainty( gw_mod: VariationalGWModule, - latent_domains: LatentsT, + latent_domains: LatentsDomainGroupsT, contrastive_fn: VarContrastiveLossType, ) -> dict[str, torch.Tensor]: losses: dict[str, torch.Tensor] = {} @@ -259,21 +297,27 @@ def __init__( self.loss_coefs = loss_coefs self.contrastive_fn = contrastive_fn - def demi_cycle_loss(self, 
latent_domains: LatentsT) -> dict[str, torch.Tensor]: + def demi_cycle_loss( + self, latent_domains: LatentsDomainGroupsT + ) -> dict[str, torch.Tensor]: return _demi_cycle_loss(self.gw_mod, self.domain_mods, latent_domains) - def cycle_loss(self, latent_domains: LatentsT) -> dict[str, torch.Tensor]: + def cycle_loss( + self, latent_domains: LatentsDomainGroupsT + ) -> dict[str, torch.Tensor]: return _cycle_loss(self.gw_mod, self.domain_mods, latent_domains) - def translation_loss(self, latent_domains: LatentsT) -> dict[str, torch.Tensor]: + def translation_loss( + self, latent_domains: LatentsDomainGroupsT + ) -> dict[str, torch.Tensor]: return _translation_loss(self.gw_mod, self.domain_mods, latent_domains) - def contrastive_loss(self, latent_domains: LatentsT) -> dict[str, torch.Tensor]: + def contrastive_loss( + self, latent_domains: LatentsDomainGroupsT + ) -> dict[str, torch.Tensor]: return _contrastive_loss(self.gw_mod, latent_domains, self.contrastive_fn) - def step( - self, domain_latents: Mapping[frozenset[str], Mapping[str, torch.Tensor]], _ - ) -> LossOutput: + def step(self, domain_latents: LatentsDomainGroupsT, _) -> LossOutput: metrics: dict[str, torch.Tensor] = {} metrics.update(self.demi_cycle_loss(domain_latents)) @@ -329,16 +373,24 @@ def __init__( self.contrastive_fn = contrastive_fn self.var_contrastive_fn = var_contrastive_fn - def demi_cycle_loss(self, latent_domains: LatentsT) -> dict[str, torch.Tensor]: + def demi_cycle_loss( + self, latent_domains: LatentsDomainGroupsT + ) -> dict[str, torch.Tensor]: return _demi_cycle_loss(self.gw_mod, self.domain_mods, latent_domains) - def cycle_loss(self, latent_domains: LatentsT) -> dict[str, torch.Tensor]: + def cycle_loss( + self, latent_domains: LatentsDomainGroupsT + ) -> dict[str, torch.Tensor]: return _cycle_loss(self.gw_mod, self.domain_mods, latent_domains) - def translation_loss(self, latent_domains: LatentsT) -> dict[str, torch.Tensor]: + def translation_loss( + self, latent_domains: LatentsDomainGroupsT + ) -> dict[str, torch.Tensor]: return _translation_loss(self.gw_mod, self.domain_mods, latent_domains) - def contrastive_loss(self, latent_domains: LatentsT) -> dict[str, torch.Tensor]: + def contrastive_loss( + self, latent_domains: LatentsDomainGroupsT + ) -> dict[str, torch.Tensor]: if self.var_contrastive_fn is not None: return _contrastive_loss_with_uncertainty( self.gw_mod, latent_domains, self.var_contrastive_fn @@ -347,7 +399,7 @@ def contrastive_loss(self, latent_domains: LatentsT) -> dict[str, torch.Tensor]: assert self.contrastive_fn is not None return _contrastive_loss(self.gw_mod, latent_domains, self.contrastive_fn) - def kl_loss(self, latent_domains: LatentsT) -> dict[str, torch.Tensor]: + def kl_loss(self, latent_domains: LatentsDomainGroupsT) -> dict[str, torch.Tensor]: losses: dict[str, torch.Tensor] = {} for domains, latents in latent_domains.items(): @@ -365,9 +417,7 @@ def kl_loss(self, latent_domains: LatentsT) -> dict[str, torch.Tensor]: losses["kl"] = torch.stack(list(losses.values()), dim=0).mean() return losses - def step( - self, domain_latents: Mapping[frozenset[str], Mapping[str, torch.Tensor]], _ - ) -> LossOutput: + def step(self, domain_latents: LatentsDomainGroupsT, _) -> LossOutput: metrics: dict[str, torch.Tensor] = {} dcy_losses = self.demi_cycle_loss(domain_latents) @@ -454,20 +504,28 @@ def __init__( self.domain_mods = domain_mods self.contrastive_fn = contrastive_fn - def demi_cycle_loss(self, latent_domains: LatentsT) -> dict[str, torch.Tensor]: + def demi_cycle_loss( + self, 
latent_domains: LatentsDomainGroupsT + ) -> dict[str, torch.Tensor]: return _demi_cycle_loss(self.gw_mod, self.domain_mods, latent_domains) - def cycle_loss(self, latent_domains: LatentsT) -> dict[str, torch.Tensor]: + def cycle_loss( + self, latent_domains: LatentsDomainGroupsT + ) -> dict[str, torch.Tensor]: return _cycle_loss(self.gw_mod, self.domain_mods, latent_domains) - def translation_loss(self, latent_domains: LatentsT) -> dict[str, torch.Tensor]: + def translation_loss( + self, latent_domains: LatentsDomainGroupsT + ) -> dict[str, torch.Tensor]: return _translation_loss(self.gw_mod, self.domain_mods, latent_domains) - def contrastive_loss(self, latent_domains: LatentsT) -> dict[str, torch.Tensor]: + def contrastive_loss( + self, latent_domains: LatentsDomainGroupsT + ) -> dict[str, torch.Tensor]: return _contrastive_loss(self.gw_mod, latent_domains, self.contrastive_fn) def broadcast_loss( - self, latent_domains: LatentsT, mode: str + self, latent_domains: LatentsDomainGroupsT, mode: str ) -> dict[str, torch.Tensor]: losses: dict[str, torch.Tensor] = {} metrics: dict[str, torch.Tensor] = {} @@ -548,7 +606,7 @@ def broadcast_loss( def step( self, - domain_latents: Mapping[frozenset[str], Mapping[str, torch.Tensor]], + domain_latents: LatentsDomainGroupsT, mode: str, ) -> LossOutput: metrics: dict[str, torch.Tensor] = {}