From d3b1497521a9ef0ef3c17ff383c724ee8197166d Mon Sep 17 00:00:00 2001 From: bdvllrs Date: Fri, 19 Jan 2024 15:06:00 +0000 Subject: [PATCH] Use ABCMeta for abstract classes. Fixes #4. --- shimmer/modules/domain.py | 69 ++++++++++++++++++++++++-- shimmer/modules/gw_module.py | 94 +++++++++++++++++++++++++++++++----- shimmer/modules/losses.py | 27 +++++++---- shimmer/modules/vae.py | 37 +++++++++++--- 4 files changed, 196 insertions(+), 31 deletions(-) diff --git a/shimmer/modules/domain.py b/shimmer/modules/domain.py index 9d261ed8..8870fc16 100644 --- a/shimmer/modules/domain.py +++ b/shimmer/modules/domain.py @@ -3,11 +3,32 @@ import lightning.pytorch as pl import torch -from torch.nn.functional import mse_loss class DomainModule(pl.LightningModule): + """ + Base class for a DomainModule. + We do not use ABCMeta here because some modules could be without an encoder or a decoder. + """ + def encode(self, x: Any) -> torch.Tensor: + """ + Encode data to the unimodal representation. + Args: + x: data of the domain. + Returns: + a unimodal representation. + """ + raise NotImplementedError + + def decode(self, z: torch.Tensor) -> Any: + """ + Decode the unimodal representation back to the domain data. + Args: + z: unimodal representation. + Returns: + the data of the domain. + """ raise NotImplementedError def on_before_gw_encode_dcy(self, x: torch.Tensor) -> torch.Tensor: @@ -22,27 +43,65 @@ def on_before_gw_encode_tr(self, x: torch.Tensor) -> torch.Tensor: def on_before_gw_encode_cy(self, x: torch.Tensor) -> torch.Tensor: return x - def decode(self, z: torch.Tensor) -> Any: - raise NotImplementedError - def compute_loss( self, pred: torch.Tensor, target: torch.Tensor ) -> dict[str, torch.Tensor]: - return {"loss": mse_loss(pred, target)} + """ + Computes the loss of the modality. If you implement compute_dcy_loss, + compute_cy_loss and compute_tr_loss independently, no need to define this + function. 
+ Args: + pred: tensor with a predicted latent unimodal representation + target: target tensor + Returns: + Dict of losses. Must contain the "loss" key with the total loss + used for training. Any other key will be logged, but not trained on. + """ + raise NotImplementedError def compute_dcy_loss( self, pred: torch.Tensor, target: torch.Tensor ) -> dict[str, torch.Tensor]: + """ + Computes the loss for a demi-cycle. Override if the demi-cycle loss is + different than the generic loss. + Args: + pred: tensor with a predicted latent unimodal representation + target: target tensor + Returns: + Dict of losses. Must contain the "loss" key with the total loss + used for training. Any other key will be logged, but not trained on. + """ return self.compute_loss(pred, target) def compute_cy_loss( self, pred: torch.Tensor, target: torch.Tensor ) -> dict[str, torch.Tensor]: + """ + Computes the loss for a cycle. Override if the cycle loss is + different than the generic loss. + Args: + pred: tensor with a predicted latent unimodal representation + target: target tensor + Returns: + Dict of losses. Must contain the "loss" key with the total loss + used for training. Any other key will be logged, but not trained on. + """ return self.compute_loss(pred, target) def compute_tr_loss( self, pred: torch.Tensor, target: torch.Tensor ) -> dict[str, torch.Tensor]: + """ + Computes the loss for a translation. Override if the translation loss is + different than the generic loss. + Args: + pred: tensor with a predicted latent unimodal representation + target: target tensor + Returns: + Dict of losses. Must contain the "loss" key with the total loss + used for training. Any other key will be logged, but not trained on. 
+ """ return self.compute_loss(pred, target) diff --git a/shimmer/modules/gw_module.py b/shimmer/modules/gw_module.py index e57dc67f..287f12bd 100644 --- a/shimmer/modules/gw_module.py +++ b/shimmer/modules/gw_module.py @@ -1,3 +1,4 @@ +from abc import ABCMeta, abstractmethod from collections.abc import Iterable, Mapping import torch @@ -78,16 +79,21 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: return self.layers(x), self.uncertainty_level.expand(x.size(0), -1) -class GWModule(nn.Module): +class GWModule(nn.Module, metaclass=ABCMeta): domain_descr: Mapping[str, DomainDescription] latent_dim: int - def fusion_mechanism(self, x: Mapping[str, torch.Tensor]) -> torch.Tensor: - raise NotImplementedError - def on_before_gw_encode_dcy( self, x: Mapping[str, torch.Tensor] ) -> dict[str, torch.Tensor]: + """ + Callback used before projecting the unimodal representations to the GW + representation when computing the demi-cycle loss. Defaults to identity. + Args: + x: mapping of domain name to latent representation. + Returns: + the same mapping with updated representations + """ return { domain: self.domain_descr[domain].module.on_before_gw_encode_dcy( x[domain] @@ -98,6 +104,14 @@ def on_before_gw_encode_dcy( def on_before_gw_encode_cy( self, x: Mapping[str, torch.Tensor] ) -> dict[str, torch.Tensor]: + """ + Callback used before projecting the unimodal representations to the GW + representation when computing the cycle loss. Defaults to identity. + Args: + x: mapping of domain name to latent representation. + Returns: + the same mapping with updated representations + """ return { domain: self.domain_descr[domain].module.on_before_gw_encode_cy( x[domain] @@ -108,6 +122,14 @@ def on_before_gw_encode_cy( def on_before_gw_encode_tr( self, x: Mapping[str, torch.Tensor] ) -> dict[str, torch.Tensor]: + """ + Callback used before projecting the unimodal representations to the GW + representation when computing the translation loss. 
Defaults to identity. + Args: + x: mapping of domain name to latent representation. + Returns: + the same mapping with updated representations + """ return { domain: self.domain_descr[domain].module.on_before_gw_encode_tr( x[domain] @@ -118,6 +140,14 @@ def on_before_gw_encode_tr( def on_before_gw_encode_cont( self, x: Mapping[str, torch.Tensor] ) -> dict[str, torch.Tensor]: + """ + Callback used before projecting the unimodal representations to the GW + representation when computing the contrastive loss. Defaults to identity. + Args: + x: mapping of domain name to latent representation. + Returns: + the same mapping with updated representations + """ return { domain: self.domain_descr[domain].module.on_before_gw_encode_cont( x[domain] @@ -125,23 +155,58 @@ def on_before_gw_encode_cont( for domain in x.keys() } + @abstractmethod def encode(self, x: Mapping[str, torch.Tensor]) -> torch.Tensor: - raise NotImplementedError - + """ + Encode the unimodal representations to the GW representation. + Args: + x: mapping of domain name to unimodal representation. + Returns: + GW representation + """ + ... + + @abstractmethod def decode( self, z: torch.Tensor, domains: Iterable[str] | None = None ) -> dict[str, torch.Tensor]: - raise NotImplementedError - + """ + Decode the GW representation to the unimodal representations. + Args: + z: GW representation + domains: iterable of domains to decode to. Defaults to all domains. + Returns: + dict of domain name to decoded unimodal representation. + """ + ... + + @abstractmethod def translate( self, x: Mapping[str, torch.Tensor], to: str ) -> torch.Tensor: - raise NotImplementedError - + """ + Translate from one domain to another. + Args: + x: mapping of domain name to unimodal representation. + to: domain to translate to. + Returns: + the unimodal representation of domain given by `to`. + """ + ... 
+ + @abstractmethod def cycle( self, x: Mapping[str, torch.Tensor], through: str ) -> dict[str, torch.Tensor]: - raise NotImplementedError + """ + Cycle from one domain through another. + Args: + x: mapping of domain name to unimodal representation. + through: domain to cycle through. + Returns: + the unimodal representations cycled through the given domain. + """ + ... def default_encoders( @@ -194,6 +259,13 @@ def __init__( ) def fusion_mechanism(self, x: Mapping[str, torch.Tensor]) -> torch.Tensor: + """ + Merge function used to combine domains. + Args: + x: mapping of domain name to latent representation. + Returns: + The merged representation + """ return torch.mean(torch.stack(list(x.values())), dim=0) def encode(self, x: Mapping[str, torch.Tensor]) -> torch.Tensor: diff --git a/shimmer/modules/losses.py b/shimmer/modules/losses.py index 53a9a5ec..3541fc5b 100644 --- a/shimmer/modules/losses.py +++ b/shimmer/modules/losses.py @@ -1,3 +1,4 @@ +from abc import ABCMeta, abstractmethod from collections.abc import Mapping from typing import Literal @@ -64,18 +65,26 @@ def contrastive_loss_with_uncertainty( return 0.5 * (ce + ce_t) -class GWLosses(torch.nn.Module): - def step( - self, - domain_latents: Mapping[frozenset[str], Mapping[str, torch.Tensor]], - ) -> dict[str, torch.Tensor]: - raise NotImplementedError +class GWLosses(torch.nn.Module, metaclass=ABCMeta): + """ + Base Abstract Class for Global Workspace (GW) losses. This module is used + to compute the different losses of the GW (typically translation, cycle, + demi-cycle, contrastive losses). + """ - def domain_metrics( + @abstractmethod + def step( self, - domain_latents: Mapping[frozenset[str], Mapping[str, torch.Tensor]], + domain_latents: LatentsT, ) -> dict[str, torch.Tensor]: - raise NotImplementedError + """ + Computes the losses + Args: + domain_latents: All latent groups + Returns: + a dict with loss name as keys and loss value as values. + """ + ... 
def _demi_cycle_loss( diff --git a/shimmer/modules/vae.py b/shimmer/modules/vae.py index 69ca21e1..fe54a93e 100644 --- a/shimmer/modules/vae.py +++ b/shimmer/modules/vae.py @@ -1,4 +1,5 @@ import math +from abc import ABCMeta, abstractmethod from collections.abc import Sequence import torch @@ -28,16 +29,40 @@ def gaussian_nll( ) -class VAEEncoder(nn.Module): +class VAEEncoder(nn.Module, metaclass=ABCMeta): + """ + Base class for a VAE encoder. + """ + + @abstractmethod def forward( self, x: Sequence[torch.Tensor] ) -> tuple[torch.Tensor, torch.Tensor]: - raise NotImplementedError - - -class VAEDecoder(nn.Module): + """ + Encode representation with VAE + Args: + x: sequence of tensors + Returns: + tuple with the mean and the log variance + """ + ... + + +class VAEDecoder(nn.Module, metaclass=ABCMeta): + """ + Base class for a VAE decoder. + """ + + @abstractmethod def forward(self, x: torch.Tensor) -> Sequence[torch.Tensor]: - raise NotImplementedError + """ + Decode representation with VAE + Args: + x: representation + Returns: + Sequence of tensors reconstructing input + """ + ... class VAE(nn.Module):