Skip to content

Commit

Permalink
Use ABCMeta for abstract classes. Fixes #4.
Browse files Browse the repository at this point in the history
  • Loading branch information
bdvllrs committed Jan 19, 2024
1 parent 663a525 commit d3b1497
Show file tree
Hide file tree
Showing 4 changed files with 196 additions and 31 deletions.
69 changes: 64 additions & 5 deletions shimmer/modules/domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,32 @@

import lightning.pytorch as pl
import torch
from torch.nn.functional import mse_loss


class DomainModule(pl.LightningModule):
"""
Base class for a DomainModule.
We do not use ABCMeta here because some modules could be without encore or decoder.
"""

def encode(self, x: Any) -> torch.Tensor:
"""
Encode data to the unimodal representation.
Args:
x: data of the domain.
Returns:
a unimodal representation.
"""
raise NotImplementedError

def decode(self, z: torch.Tensor) -> Any:
"""
Decode data back to the unimodal representation.
Args:
x: data of the domain.
Returns:
a unimodal representation.
"""
raise NotImplementedError

def on_before_gw_encode_dcy(self, x: torch.Tensor) -> torch.Tensor:
Expand All @@ -22,27 +43,65 @@ def on_before_gw_encode_tr(self, x: torch.Tensor) -> torch.Tensor:
def on_before_gw_encode_cy(self, x: torch.Tensor) -> torch.Tensor:
return x

def decode(self, z: torch.Tensor) -> Any:
raise NotImplementedError

def compute_loss(
self, pred: torch.Tensor, target: torch.Tensor
) -> dict[str, torch.Tensor]:
return {"loss": mse_loss(pred, target)}
"""
Computes the loss of the modality. If you implement compute_dcy_loss,
compute_cy_loss and compute_tr_loss independently, no need to define this
function.
Args:
pred: tensor with a predicted latent unimodal representation
target: target tensor
Results:
Dict of losses. Must contain the "loss" key with the total loss
used for training. Any other key will be logged, but not trained on.
"""
raise NotImplementedError

def compute_dcy_loss(
self, pred: torch.Tensor, target: torch.Tensor
) -> dict[str, torch.Tensor]:
"""
Computes the loss for a demi-cycle. Override if the demi-cycle loss is
different that the generic loss.
Args:
pred: tensor with a predicted latent unimodal representation
target: target tensor
Results:
Dict of losses. Must contain the "loss" key with the total loss
used for training. Any other key will be logged, but not trained on.
"""
return self.compute_loss(pred, target)

def compute_cy_loss(
self, pred: torch.Tensor, target: torch.Tensor
) -> dict[str, torch.Tensor]:
"""
Computes the loss for a cycle. Override if the cycle loss is
different that the generic loss.
Args:
pred: tensor with a predicted latent unimodal representation
target: target tensor
Results:
Dict of losses. Must contain the "loss" key with the total loss
used for training. Any other key will be logged, but not trained on.
"""
return self.compute_loss(pred, target)

def compute_tr_loss(
self, pred: torch.Tensor, target: torch.Tensor
) -> dict[str, torch.Tensor]:
"""
Computes the loss for a translation. Override if the translation loss is
different that the generic loss.
Args:
pred: tensor with a predicted latent unimodal representation
target: target tensor
Results:
Dict of losses. Must contain the "loss" key with the total loss
used for training. Any other key will be logged, but not trained on.
"""
return self.compute_loss(pred, target)


Expand Down
94 changes: 83 additions & 11 deletions shimmer/modules/gw_module.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from abc import ABCMeta, abstractmethod
from collections.abc import Iterable, Mapping

import torch
Expand Down Expand Up @@ -78,16 +79,21 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
return self.layers(x), self.uncertainty_level.expand(x.size(0), -1)


class GWModule(nn.Module):
class GWModule(nn.Module, metaclass=ABCMeta):
domain_descr: Mapping[str, DomainDescription]
latent_dim: int

def fusion_mechanism(self, x: Mapping[str, torch.Tensor]) -> torch.Tensor:
raise NotImplementedError

def on_before_gw_encode_dcy(
self, x: Mapping[str, torch.Tensor]
) -> dict[str, torch.Tensor]:
"""
Callback used before projecting the unimodal representations to the GW
representation when computing the demi-cycle loss. Defaults to identity.
Args:
x: mapping of domain name to latent representation.
Returns:
the same mapping with updated representations
"""
return {
domain: self.domain_descr[domain].module.on_before_gw_encode_dcy(
x[domain]
Expand All @@ -98,6 +104,14 @@ def on_before_gw_encode_dcy(
def on_before_gw_encode_cy(
self, x: Mapping[str, torch.Tensor]
) -> dict[str, torch.Tensor]:
"""
Callback used before projecting the unimodal representations to the GW
representation when computing the cycle loss. Defaults to identity.
Args:
x: mapping of domain name to latent representation.
Returns:
the same mapping with updated representations
"""
return {
domain: self.domain_descr[domain].module.on_before_gw_encode_cy(
x[domain]
Expand All @@ -108,6 +122,14 @@ def on_before_gw_encode_cy(
def on_before_gw_encode_tr(
self, x: Mapping[str, torch.Tensor]
) -> dict[str, torch.Tensor]:
"""
Callback used before projecting the unimodal representations to the GW
representation when computing the translation loss. Defaults to identity.
Args:
x: mapping of domain name to latent representation.
Returns:
the same mapping with updated representations
"""
return {
domain: self.domain_descr[domain].module.on_before_gw_encode_tr(
x[domain]
Expand All @@ -118,30 +140,73 @@ def on_before_gw_encode_tr(
def on_before_gw_encode_cont(
self, x: Mapping[str, torch.Tensor]
) -> dict[str, torch.Tensor]:
"""
Callback used before projecting the unimodal representations to the GW
representation when computing the contrastive loss. Defaults to identity.
Args:
x: mapping of domain name to latent representation.
Returns:
the same mapping with updated representations
"""
return {
domain: self.domain_descr[domain].module.on_before_gw_encode_cont(
x[domain]
)
for domain in x.keys()
}

@abstractmethod
def encode(self, x: Mapping[str, torch.Tensor]) -> torch.Tensor:
raise NotImplementedError

"""
Encode the unimodal representations to the GW representation.
Args:
x: mapping of domain name to unimodal representation.
Returns:
GW representation
"""
...

@abstractmethod
def decode(
self, z: torch.Tensor, domains: Iterable[str] | None = None
) -> dict[str, torch.Tensor]:
raise NotImplementedError

"""
Decode the GW representation to the unimodal representations.
Args:
z: GW representation
domains: iterable of domains to decode to. Defaults to all domains.
Returns:
dict of domain name to decoded unimodal representation.
"""
...

@abstractmethod
def translate(
self, x: Mapping[str, torch.Tensor], to: str
) -> torch.Tensor:
raise NotImplementedError

"""
Translate from one domain to another.
Args:
x: mapping of domain name to unimodal representation.
to: domain to translate to.
Returns:
the unimodal representation of domain given by `to`.
"""
...

@abstractmethod
def cycle(
self, x: Mapping[str, torch.Tensor], through: str
) -> dict[str, torch.Tensor]:
raise NotImplementedError
"""
Cycle from one domain through another.
Args:
x: mapping of domain name to unimodal representation.
through: domain to translate to.
Returns:
the unimodal representations cycles through the given domain.
"""
...


def default_encoders(
Expand Down Expand Up @@ -194,6 +259,13 @@ def __init__(
)

def fusion_mechanism(self, x: Mapping[str, torch.Tensor]) -> torch.Tensor:
"""
Merge function used to combine domains.
Args:
x: mapping of domain name to latent representation.
Returns:
The merged representation
"""
return torch.mean(torch.stack(list(x.values())), dim=0)

def encode(self, x: Mapping[str, torch.Tensor]) -> torch.Tensor:
Expand Down
27 changes: 18 additions & 9 deletions shimmer/modules/losses.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from abc import ABCMeta, abstractmethod
from collections.abc import Mapping
from typing import Literal

Expand Down Expand Up @@ -64,18 +65,26 @@ def contrastive_loss_with_uncertainty(
return 0.5 * (ce + ce_t)


class GWLosses(torch.nn.Module):
def step(
self,
domain_latents: Mapping[frozenset[str], Mapping[str, torch.Tensor]],
) -> dict[str, torch.Tensor]:
raise NotImplementedError
class GWLosses(torch.nn.Module, metaclass=ABCMeta):
"""
Base Abstract Class for Global Workspace (GW) losses. This module is used
to compute the different losses of the GW (typically translation, cycle,
demi-cycle, contrastive losses).
"""

def domain_metrics(
@abstractmethod
def step(
self,
domain_latents: Mapping[frozenset[str], Mapping[str, torch.Tensor]],
domain_latents: LatentsT,
) -> dict[str, torch.Tensor]:
raise NotImplementedError
"""
Computes the losses
Args:
domain_latents: All latent groups
Returns:
a dict with loss name as keys and loss value as values.
"""
...


def _demi_cycle_loss(
Expand Down
37 changes: 31 additions & 6 deletions shimmer/modules/vae.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import math
from abc import ABCMeta, abstractmethod
from collections.abc import Sequence

import torch
Expand Down Expand Up @@ -28,16 +29,40 @@ def gaussian_nll(
)


class VAEEncoder(nn.Module):
class VAEEncoder(nn.Module, metaclass=ABCMeta):
"""
Base class for a VAE encoder.
"""

@abstractmethod
def forward(
self, x: Sequence[torch.Tensor]
) -> tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError


class VAEDecoder(nn.Module):
"""
Encode representation with VAE
Args:
x: sequence of tensors
Retunrs:
tuple with the mean and the log variance
"""
...


class VAEDecoder(nn.Module, metaclass=ABCMeta):
"""
Base class for a VAE decoder.
"""

@abstractmethod
def forward(self, x: torch.Tensor) -> Sequence[torch.Tensor]:
raise NotImplementedError
"""
Decode representation with VAE
Args:
x: representation
Retunrs:
Sequence of tensors reconstructing input
"""
...


class VAE(nn.Module):
Expand Down

0 comments on commit d3b1497

Please sign in to comment.