Use functions instead of classes for GlobalWorkspaces.
bdvllrs committed Dec 12, 2023
1 parent 91a6f24 commit 0965355
Showing 5 changed files with 136 additions and 135 deletions.
12 changes: 11 additions & 1 deletion CHANGELOG.md
@@ -8,4 +8,14 @@ Fix missing individual metrics for translation loss.
Fix wrong module used to compute the cycle losses. Don't do cycle with the same domain as target and source.

# 0.2.0
Add callback on_before_gw_encode and individual compute_losses for each loss type.
Add callback on\_before\_gw\_encode and individual compute\_losses for each loss type.
Fix bugs

# 0.3.0
* Breaking change: remove `DeterministicGlobalWorkspace` and `VariationalGlobalWorkspace`
  in favor of the functions `global_workspace` and `variational_global_workspace`
  (see the migration sketch below).
* Allow setting custom GW encoders and decoders.
* Breaking change: remove `self.input_dim`, `self.encoder_hidden_dim`,
`self.encoder_n_layers`, `self.decoder_hidden_dim`, and `self.decoder_n_layers`
in `GWModule`s.
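A minimal migration sketch for this breaking change, assuming `domain_descriptions: Mapping[str, DomainDescription]` and a `loss_coefs` dict of tensors are built exactly as they were for the removed classes; the `latent_dim` value is illustrative only:

# Before (0.2.x), hypothetical caller code:
# gw = DeterministicGlobalWorkspace(domain_descriptions, latent_dim=12, loss_coefs=loss_coefs)

# After (0.3.0):
from shimmer.modules.global_workspace import global_workspace

gw = global_workspace(
    domain_descriptions,
    latent_dim=12,
    loss_coefs=loss_coefs,
)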

Empty file removed: main.yaml
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "shimmer"
version = "0.2.0"
version = "0.3.0"
description = "A light GLoW"
authors = ["bdvllrs <[email protected]>"]
license = "MIT"
145 changes: 74 additions & 71 deletions shimmer/modules/global_workspace.py
@@ -4,12 +4,13 @@
import torch
from lightning.pytorch import LightningModule
from lightning.pytorch.utilities.types import OptimizerLRSchedulerConfig
from torch.nn import ModuleDict
from torch.nn import Module, ModuleDict
from torch.optim.lr_scheduler import OneCycleLR

from shimmer.modules.dict_buffer import DictBuffer
from shimmer.modules.domain import DomainDescription, DomainModule
from shimmer.modules.gw_module import (DeterministicGWModule, GWModule,
from shimmer.modules.gw_module import (DeterministicGWModule, GWDecoder,
GWEncoder, GWModule,
VariationalGWModule)
from shimmer.modules.losses import (DeterministicGWLosses, GWLosses, LatentsT,
VariationalGWLosses)
@@ -251,72 +252,74 @@ def configure_optimizers(self) -> OptimizerLRSchedulerConfig:
}


class DeterministicGlobalWorkspace(GlobalWorkspace):
def __init__(
self,
domain_descriptions: Mapping[str, DomainDescription],
latent_dim: int,
loss_coefs: dict[str, torch.Tensor],
optim_lr: float = 1e-3,
optim_weight_decay: float = 0.0,
scheduler_args: SchedulerArgs | None = None,
) -> None:
gw_mod = DeterministicGWModule(domain_descriptions, latent_dim)

domain_mods = {
name: domain.module for name, domain in domain_descriptions.items()
}
for mod in domain_mods.values():
mod.freeze()
domain_mods = cast(dict[str, DomainModule], ModuleDict(domain_mods))

coef_buffers = DictBuffer(loss_coefs)

loss_mod = DeterministicGWLosses(gw_mod, domain_mods, coef_buffers)

super().__init__(
gw_mod,
domain_mods,
coef_buffers,
loss_mod,
optim_lr,
optim_weight_decay,
scheduler_args,
)


class VariationalGlobalWorkspace(GlobalWorkspace):
def __init__(
self,
domain_descriptions: Mapping[str, DomainDescription],
latent_dim: int,
loss_coefs: dict[str, torch.Tensor],
var_contrastive_loss: bool = False,
optim_lr: float = 1e-3,
optim_weight_decay: float = 0.0,
scheduler_args: SchedulerArgs | None = None,
) -> None:
gw_mod = VariationalGWModule(domain_descriptions, latent_dim)

domain_mods = {
name: domain.module for name, domain in domain_descriptions.items()
}
for mod in domain_mods.values():
mod.freeze()
domain_mods = cast(dict[str, DomainModule], ModuleDict(domain_mods))

coef_buffers = DictBuffer(loss_coefs)

loss_mod = VariationalGWLosses(
gw_mod, domain_mods, coef_buffers, var_contrastive_loss
)

super().__init__(
gw_mod,
domain_mods,
coef_buffers,
loss_mod,
optim_lr,
optim_weight_decay,
scheduler_args,
)
def global_workspace(
domain_descriptions: Mapping[str, DomainDescription],
latent_dim: int,
loss_coefs: dict[str, torch.Tensor],
optim_lr: float = 1e-3,
optim_weight_decay: float = 0.0,
scheduler_args: SchedulerArgs | None = None,
gw_encoders: Mapping[str, Module] | None = None,
gw_decoders: Mapping[str, Module] | None = None,
) -> GlobalWorkspace:
gw_mod = DeterministicGWModule(
domain_descriptions, latent_dim, gw_encoders, gw_decoders
)

domain_mods = {
name: domain.module for name, domain in domain_descriptions.items()
}
for mod in domain_mods.values():
mod.freeze()
domain_mods = cast(dict[str, DomainModule], ModuleDict(domain_mods))

coef_buffers = DictBuffer(loss_coefs)

loss_mod = DeterministicGWLosses(gw_mod, domain_mods, coef_buffers)
return GlobalWorkspace(
gw_mod,
domain_mods,
coef_buffers,
loss_mod,
optim_lr,
optim_weight_decay,
scheduler_args,
)


def variational_global_workspace(
domain_descriptions: Mapping[str, DomainDescription],
latent_dim: int,
loss_coefs: dict[str, torch.Tensor],
var_contrastive_loss: bool = False,
optim_lr: float = 1e-3,
optim_weight_decay: float = 0.0,
scheduler_args: SchedulerArgs | None = None,
gw_encoders: Mapping[str, Module] | None = None,
gw_decoders: Mapping[str, Module] | None = None,
) -> GlobalWorkspace:
gw_mod = VariationalGWModule(
domain_descriptions, latent_dim, gw_encoders, gw_decoders
)

domain_mods = {
name: domain.module for name, domain in domain_descriptions.items()
}
for mod in domain_mods.values():
mod.freeze()
domain_mods = cast(dict[str, DomainModule], ModuleDict(domain_mods))

coef_buffers = DictBuffer(loss_coefs)

loss_mod = VariationalGWLosses(
gw_mod, domain_mods, coef_buffers, var_contrastive_loss
)
return GlobalWorkspace(
gw_mod,
domain_mods,
coef_buffers,
loss_mod,
optim_lr,
optim_weight_decay,
scheduler_args,
)
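
A hedged usage sketch of the new `gw_encoders`/`gw_decoders` arguments. The domain name "v", its 32-dimensional latent, and the plain `nn.Sequential` modules are hypothetical; the diff only requires a `Mapping[str, Module]` keyed by domain name, and a mapping that is passed replaces the default encoders/decoders entirely (pass `None` to keep the defaults).

import torch.nn as nn

from shimmer.modules.global_workspace import global_workspace

latent_dim = 12          # illustrative GW latent size
domain_latent_dim = 32   # illustrative; should match the domain's DomainDescription.latent_dim

# One entry per domain in `domain_descriptions` (here a single hypothetical domain "v").
custom_encoders = {
    "v": nn.Sequential(
        nn.Linear(domain_latent_dim, 128),
        nn.ReLU(),
        nn.Linear(128, latent_dim),
    )
}
custom_decoders = {
    "v": nn.Sequential(
        nn.Linear(latent_dim, 128),
        nn.ReLU(),
        nn.Linear(128, domain_latent_dim),
    )
}

gw = global_workspace(
    domain_descriptions,    # assumed defined by the caller
    latent_dim=latent_dim,
    loss_coefs=loss_coefs,  # assumed defined by the caller
    gw_encoders=custom_encoders,
    gw_decoders=custom_decoders,
)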
112 changes: 50 additions & 62 deletions shimmer/modules/gw_module.py
@@ -143,52 +143,53 @@ def cycle(
raise NotImplementedError


def default_encoders(
domain_descriptions: Mapping[str, DomainDescription], latent_dim: int
) -> dict[str, GWEncoder]:
return {
name: GWEncoder(
domain.latent_dim,
domain.encoder_hidden_dim,
latent_dim,
domain.encoder_n_layers,
)
for name, domain in domain_descriptions.items()
}


def default_decoders(
domain_descriptions: Mapping[str, DomainDescription], latent_dim: int
) -> dict[str, GWDecoder]:
return {
name: GWDecoder(
            latent_dim,
            domain.decoder_hidden_dim,
            domain.latent_dim,
domain.decoder_n_layers,
)
for name, domain in domain_descriptions.items()
}


class DeterministicGWModule(GWModule):
def __init__(
self,
domain_descriptions: Mapping[str, DomainDescription],
latent_dim: int,
encoders: Mapping[str, nn.Module] | None = None,
decoders: Mapping[str, nn.Module] | None = None,
):
super().__init__()

self.domains = set(domain_descriptions.keys())
self.domain_descr = domain_descriptions
self.latent_dim = latent_dim

self.input_dim: dict[str, int] = {}
self.encoder_hidden_dim: dict[str, int] = {}
self.encoder_n_layers: dict[str, int] = {}
self.decoder_hidden_dim: dict[str, int] = {}
self.decoder_n_layers: dict[str, int] = {}

for name, domain in domain_descriptions.items():
self.input_dim[name] = domain.latent_dim
self.encoder_hidden_dim[name] = domain.encoder_hidden_dim
self.encoder_n_layers[name] = domain.encoder_n_layers
self.decoder_hidden_dim[name] = domain.decoder_hidden_dim
self.decoder_n_layers[name] = domain.decoder_n_layers

self.encoders = nn.ModuleDict(
{
domain: GWEncoder(
self.input_dim[domain],
self.encoder_hidden_dim[domain],
self.latent_dim,
self.encoder_n_layers[domain],
)
for domain in self.domains
}
encoders or default_encoders(domain_descriptions, latent_dim)
)
self.decoders = nn.ModuleDict(
{
domain: GWDecoder(
self.latent_dim,
self.decoder_hidden_dim[domain],
self.input_dim[domain],
self.decoder_n_layers[domain],
)
for domain in self.domains
}
decoders or default_decoders(domain_descriptions, latent_dim)
)

def fusion_mechanism(self, x: Mapping[str, torch.Tensor]) -> torch.Tensor:
@@ -223,52 +224,39 @@ def cycle(
}


def default_var_encoders(
domain_descriptions: Mapping[str, DomainDescription], latent_dim: int
) -> dict[str, VariationalGWEncoder]:
return {
name: VariationalGWEncoder(
domain.latent_dim,
domain.encoder_hidden_dim,
latent_dim,
domain.encoder_n_layers,
)
for name, domain in domain_descriptions.items()
}


class VariationalGWModule(GWModule):
def __init__(
self,
domain_descriptions: Mapping[str, DomainDescription],
latent_dim: int,
encoders: Mapping[str, nn.Module] | None = None,
decoders: Mapping[str, nn.Module] | None = None,
):
super().__init__()

self.domains = set(domain_descriptions.keys())
self.domain_descr = domain_descriptions
self.latent_dim = latent_dim

self.input_dim: dict[str, int] = {}
self.encoder_hidden_dim: dict[str, int] = {}
self.encoder_n_layers: dict[str, int] = {}
self.decoder_hidden_dim: dict[str, int] = {}
self.decoder_n_layers: dict[str, int] = {}

for name, domain in domain_descriptions.items():
self.input_dim[name] = domain.latent_dim
self.encoder_hidden_dim[name] = domain.encoder_hidden_dim
self.encoder_n_layers[name] = domain.encoder_n_layers
self.decoder_hidden_dim[name] = domain.decoder_hidden_dim
self.decoder_n_layers[name] = domain.decoder_n_layers

self.encoders = nn.ModuleDict(
{
domain: VariationalGWEncoder(
self.input_dim[domain],
self.encoder_hidden_dim[domain],
self.latent_dim,
self.encoder_n_layers[domain],
)
for domain in self.domains
}
encoders or default_var_encoders(domain_descriptions, latent_dim)
)
self.decoders = nn.ModuleDict(
{
domain: GWDecoder(
self.latent_dim,
self.decoder_hidden_dim[domain],
self.input_dim[domain],
self.decoder_n_layers[domain],
)
for domain in self.domains
}
decoders or default_decoders(domain_descriptions, latent_dim)
)

def fusion_mechanism(self, x: Mapping[str, torch.Tensor]) -> torch.Tensor:
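
From the constructor calls in this diff, `GWEncoder` and `GWDecoder` appear to take `(in_dim, hidden_dim, out_dim, n_layers)`. Below is a small sketch of what the new `default_encoders`/`default_decoders` helpers build for each domain; the domain name "v", the `domain_descriptions` mapping, and the `latent_dim` value are assumptions, not part of this commit.

from shimmer.modules.gw_module import (GWDecoder, GWEncoder, default_decoders,
                                        default_encoders)

latent_dim = 12  # illustrative GW latent size

# Per-domain modules equivalent to what the default_* helpers construct:
desc = domain_descriptions["v"]  # hypothetical domain name, mapping assumed defined
enc = GWEncoder(desc.latent_dim, desc.encoder_hidden_dim, latent_dim, desc.encoder_n_layers)
dec = GWDecoder(latent_dim, desc.decoder_hidden_dim, desc.latent_dim, desc.decoder_n_layers)

# The helpers do the same for every domain at once:
encoders = default_encoders(domain_descriptions, latent_dim)
decoders = default_decoders(domain_descriptions, latent_dim)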
