Skip to content

Commit

Permalink
Remove Interfaces and replace it with gw_encoders and gw_decoders
Browse files Browse the repository at this point in the history
  • Loading branch information
bdvllrs committed Mar 6, 2024
1 parent 0dc8d1f commit d77608f
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 246 deletions.
6 changes: 0 additions & 6 deletions shimmer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,9 @@
from shimmer.modules.gw_module import (
GWDecoder,
GWEncoder,
GWInterface,
GWInterfaceBase,
GWModule,
GWModuleBase,
VariationalGWEncoder,
VariationalGWInterface,
VariationalGWModule,
)
from shimmer.modules.losses import (
Expand Down Expand Up @@ -69,9 +66,6 @@
"GWDecoder",
"GWEncoder",
"VariationalGWEncoder",
"GWInterfaceBase",
"GWInterface",
"VariationalGWInterface",
"GWModuleBase",
"GWModule",
"VariationalGWModule",
Expand Down
6 changes: 0 additions & 6 deletions shimmer/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,9 @@
from shimmer.modules.gw_module import (
GWDecoder,
GWEncoder,
GWInterface,
GWInterfaceBase,
GWModule,
GWModuleBase,
VariationalGWEncoder,
VariationalGWInterface,
VariationalGWModule,
)
from shimmer.modules.losses import (
Expand Down Expand Up @@ -55,9 +52,6 @@
"GWDecoder",
"GWEncoder",
"VariationalGWEncoder",
"GWInterfaceBase",
"GWInterface",
"VariationalGWInterface",
"GWModuleBase",
"GWModule",
"VariationalGWModule",
Expand Down
91 changes: 62 additions & 29 deletions shimmer/modules/global_workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch
from lightning.pytorch import LightningModule
from lightning.pytorch.utilities.types import OptimizerLRSchedulerConfig
from torch.nn import ModuleDict
from torch.nn import Module, ModuleDict
from torch.optim.lr_scheduler import OneCycleLR

from shimmer.modules.contrastive_loss import (
Expand All @@ -16,7 +16,6 @@
)
from shimmer.modules.domain import DomainModule
from shimmer.modules.gw_module import (
GWInterfaceBase,
GWModule,
GWModuleBase,
GWModuleFusion,
Expand Down Expand Up @@ -87,7 +86,6 @@ class GlobalWorkspaceBase(LightningModule):
def __init__(
self,
gw_mod: GWModuleBase,
domain_mods: Mapping[str, DomainModule],
loss_mod: GWLossesBase,
optim_lr: float = 1e-3,
optim_weight_decay: float = 0.0,
Expand All @@ -97,9 +95,6 @@ def __init__(
Args:
gw_mod (`GWModuleBase`): the GWModule
domain_mods (`Mapping[str, DomainModule]`): mapping of the domains
connected to the GW. Keys are domain names, values are the
`DomainModule`.
loss_mod (`GWLossesBase`): module to compute the GW losses.
optim_lr (`float`): learning rate
optim_weight_decay (`float`): weight decay
Expand All @@ -114,14 +109,14 @@ def __init__(
"loss_mod",
"domain_descriptions",
"contrastive_loss",
"gw_interfaces",
"gw_encoders",
"gw_decoders",
]
)

self.gw_mod = gw_mod
""" a `GWModuleBase` implementation."""
self.domain_mods = domain_mods
"""Mapping of `DomainModule`s."""

self.loss_mod = loss_mod
"""The module that computes losses of the GW"""

Expand All @@ -131,6 +126,10 @@ def __init__(
if scheduler_args is not None:
self.scheduler_args.update(scheduler_args)

@property
def domain_mods(self) -> Mapping[str, DomainModule]:
return self.gw_mod.domain_mods

@property
def workspace_dim(self) -> int:
"""Dimension of the GW."""
Expand Down Expand Up @@ -511,7 +510,8 @@ class GlobalWorkspace(GlobalWorkspaceBase):
def __init__(
self,
domain_mods: Mapping[str, DomainModule],
gw_interfaces: Mapping[str, GWInterfaceBase],
gw_encoders: Mapping[str, Module],
gw_decoders: Mapping[str, Module],
workspace_dim: int,
loss_coefs: LossCoefs,
optim_lr: float = 1e-3,
Expand All @@ -526,9 +526,12 @@ def __init__(
domain_mods (`Mapping[str, DomainModule]`): mapping of the domains
connected to the GW. Keys are domain names, values are the
`DomainModule`.
gw_interfaces (`Mapping[str, GWInterfaceBase]`): mapping for each domain
name to a `GWInterfaceBase` class which role is to encode/decode
gw_encoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
name to a `torch.nn.Module` class whose role is to encode a
unimodal latent representation into a GW representation (pre fusion).
gw_decoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
name to a `torch.nn.Module` class whose role is to decode a
GW representation into a unimodal latent representation.
workspace_dim (`int`): dimension of the GW.
loss_coefs (`LossCoefs`): loss coefficients
optim_lr (`float`): learning rate
Expand All @@ -540,8 +543,14 @@ def __init__(
function used for alignment. `learn_logit_scale` will not affect custom
contrastive losses.
"""
gw_mod = GWModule(gw_interfaces, workspace_dim)
domain_mods = freeze_domain_modules(domain_mods)

gw_mod = GWModule(
domain_mods,
workspace_dim,
gw_encoders, # type: ignore
gw_decoders, # type: ignore
)
if contrastive_loss is None:
contrastive_loss = ContrastiveLoss(
torch.tensor([1 / 0.07]).log(), "mean", learn_logit_scale
Expand All @@ -555,7 +564,6 @@ def __init__(

super().__init__(
gw_mod,
domain_mods,
loss_mod,
optim_lr,
optim_weight_decay,
Expand All @@ -573,7 +581,8 @@ class VariationalGlobalWorkspace(GlobalWorkspaceBase):
def __init__(
self,
domain_mods: Mapping[str, DomainModule],
gw_interfaces: Mapping[str, GWInterfaceBase],
gw_encoders: Mapping[str, Module],
gw_decoders: Mapping[str, Module],
workspace_dim: int,
loss_coefs: VariationalLossCoefs,
use_var_contrastive_loss: bool = False,
Expand All @@ -590,9 +599,12 @@ def __init__(
domain_mods (`Mapping[str, DomainModule]`): mapping of the domains
connected to the GW. Keys are domain names, values are the
`DomainModule`.
gw_interfaces (`Mapping[str, GWInterfaceBase]`): mapping for each domain
name to a `GWInterfaceBase` class which role is to encode/decode
gw_encoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
name to a `torch.nn.Module` class whose role is to encode a
unimodal latent representation into a GW representation (pre fusion).
gw_decoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
name to a `torch.nn.Module` class whose role is to decode a
GW representation into a unimodal latent representation.
workspace_dim (`int`): dimension of the GW.
loss_coefs (`VariationalLossCoefs`): loss coefficients
use_var_contrastive_loss (`bool`): whether to use the variational
Expand All @@ -609,9 +621,15 @@ def __init__(
contrastive loss. Only used if `use_var_contrastive_loss` is set to
`True`.
"""
gw_mod = VariationalGWModule(gw_interfaces, workspace_dim)
domain_mods = freeze_domain_modules(domain_mods)

gw_mod = VariationalGWModule(
domain_mods,
workspace_dim,
gw_encoders, # type: ignore
gw_decoders, # type: ignore
)

if use_var_contrastive_loss:
if var_contrastive_loss is None:
var_contrastive_loss = ContrastiveLossWithUncertainty(
Expand All @@ -637,7 +655,6 @@ def __init__(

super().__init__(
gw_mod,
domain_mods,
loss_mod,
optim_lr,
optim_weight_decay,
Expand All @@ -655,7 +672,8 @@ class GlobalWorkspaceFusion(GlobalWorkspaceBase):
def __init__(
self,
domain_mods: Mapping[str, DomainModule],
gw_interfaces: Mapping[str, GWInterfaceBase],
gw_encoders: Mapping[str, Module],
gw_decoders: Mapping[str, Module],
workspace_dim: int,
optim_lr: float = 1e-3,
optim_weight_decay: float = 0.0,
Expand All @@ -669,9 +687,12 @@ def __init__(
domain_mods (`Mapping[str, DomainModule]`): mapping of the domains
connected to the GW. Keys are domain names, values are the
`DomainModule`.
gw_interfaces (`Mapping[str, GWInterfaceBase]`): mapping for each domain
name to a `GWInterfaceBase` class which role is to encode/decode
gw_encoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
name to a `torch.nn.Module` class whose role is to encode a
unimodal latent representation into a GW representation (pre fusion).
gw_decoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
name to a `torch.nn.Module` class whose role is to decode a
GW representation into a unimodal latent representation.
workspace_dim (`int`): dimension of the GW.
optim_lr (`float`): learning rate
optim_weight_decay (`float`): weight decay
Expand All @@ -682,8 +703,13 @@ def __init__(
function used for alignment. `learn_logit_scale` will not affect custom
contrastive losses.
"""
gw_mod = GWModuleFusion(gw_interfaces, workspace_dim)
domain_mods = freeze_domain_modules(domain_mods)
gw_mod = GWModuleFusion(
domain_mods,
workspace_dim,
gw_encoders, # type: ignore
gw_decoders, # type: ignore
)

if contrastive_loss is None:
contrastive_loss = ContrastiveLoss(
Expand All @@ -697,7 +723,6 @@ def __init__(

super().__init__(
gw_mod,
domain_mods,
loss_mod,
optim_lr,
optim_weight_decay,
Expand All @@ -708,7 +733,8 @@ def __init__(
def pretrained_global_workspace(
checkpoint_path: str | Path,
domain_mods: Mapping[str, DomainModule],
gw_interfaces: Mapping[str, GWInterfaceBase],
gw_encoders: Mapping[str, Module],
gw_decoders: Mapping[str, Module],
workspace_dim: int,
loss_coefs: LossCoefs,
contrastive_fn: ContrastiveLossType,
Expand All @@ -722,9 +748,12 @@ def pretrained_global_workspace(
domain_mods (`Mapping[str, DomainModule]`): mapping of the domains
connected to the GW. Keys are domain names, values are the
`DomainModule`.
gw_interfaces (`Mapping[str, GWInterfaceBase]`): mapping for each domain
name to a `GWInterfaceBase` class which role is to encode/decode
gw_encoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
name to a `torch.nn.Module` class whose role is to encode a
unimodal latent representation into a GW representation (pre fusion).
gw_decoders (`Mapping[str, torch.nn.Module]`): mapping for each domain
name to a `torch.nn.Module` class whose role is to decode a
GW representation into a unimodal latent representation.
workspace_dim (`int`): dimension of the GW.
loss_coefs (`LossCoefs`): loss coefficients
contrastive_fn (`ContrastiveLossType`): a contrastive loss
Expand All @@ -739,8 +768,13 @@ def pretrained_global_workspace(
Raises:
`TypeError`: if loaded type is not `GlobalWorkspace`.
"""
gw_mod = GWModule(gw_interfaces, workspace_dim)
domain_mods = freeze_domain_modules(domain_mods)
gw_mod = GWModule(
domain_mods,
workspace_dim,
gw_encoders, # type: ignore
gw_decoders, # type: ignore
)
loss_mod = GWLosses(
gw_mod,
domain_mods,
Expand All @@ -750,7 +784,6 @@ def pretrained_global_workspace(

gw = GlobalWorkspace.load_from_checkpoint(
checkpoint_path,
domain_mods=domain_mods,
gw_mod=gw_mod,
loss_coefs=loss_coefs,
loss_mod=loss_mod,
Expand Down
Loading

0 comments on commit d77608f

Please sign in to comment.