diff --git a/CHANGELOG.md b/CHANGELOG.md
index 44709e37..1b9f2a27 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,4 +8,14 @@ Fix missing individual metrics for translation loss.
 Fix wrong module used to compute the cycle losses.
 Don't do cycle with the same domain as target and source.
 # 0.2.0
-Add callback on_before_gw_encode and individual compute_losses for each loss type.
+Add callback on\_before\_gw\_encode and individual compute\_losses for each loss type.
+Fix bugs.
+
+# 0.3.0
+* Breaking change: remove `DeterministicGlobalWorkspace` and `VariationalGlobalWorkspace`
+in favor of the functions `global_workspace` and `variational_global_workspace`.
+* Allow setting custom GW encoders and decoders.
+* Breaking change: remove `self.input_dim`, `self.encoder_hidden_dim`,
+`self.encoder_n_layers`, `self.decoder_hidden_dim`, and `self.decoder_n_layers`
+in `GWModule`s.
+
diff --git a/main.yaml b/main.yaml
deleted file mode 100644
index e69de29b..00000000
diff --git a/pyproject.toml b/pyproject.toml
index 33a854f5..b9fe815f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "shimmer"
-version = "0.2.0"
+version = "0.3.0"
 description = "A light GLoW"
 authors = ["bdvllrs "]
 license = "MIT"
diff --git a/shimmer/modules/global_workspace.py b/shimmer/modules/global_workspace.py
index cd46b552..a39ee647 100644
--- a/shimmer/modules/global_workspace.py
+++ b/shimmer/modules/global_workspace.py
@@ -4,12 +4,13 @@ import torch
 
 from lightning.pytorch import LightningModule
 from lightning.pytorch.utilities.types import OptimizerLRSchedulerConfig
-from torch.nn import ModuleDict
+from torch.nn import Module, ModuleDict
 from torch.optim.lr_scheduler import OneCycleLR
 
 from shimmer.modules.dict_buffer import DictBuffer
 from shimmer.modules.domain import DomainDescription, DomainModule
-from shimmer.modules.gw_module import (DeterministicGWModule, GWModule,
+from shimmer.modules.gw_module import (DeterministicGWModule, GWDecoder,
+                                       GWEncoder, GWModule,
                                        VariationalGWModule)
 from shimmer.modules.losses import (DeterministicGWLosses, GWLosses, LatentsT,
                                     VariationalGWLosses)
@@ -251,72 +252,74 @@ def configure_optimizers(self) -> OptimizerLRSchedulerConfig:
         }
 
 
-class DeterministicGlobalWorkspace(GlobalWorkspace):
-    def __init__(
-        self,
-        domain_descriptions: Mapping[str, DomainDescription],
-        latent_dim: int,
-        loss_coefs: dict[str, torch.Tensor],
-        optim_lr: float = 1e-3,
-        optim_weight_decay: float = 0.0,
-        scheduler_args: SchedulerArgs | None = None,
-    ) -> None:
-        gw_mod = DeterministicGWModule(domain_descriptions, latent_dim)
-
-        domain_mods = {
-            name: domain.module for name, domain in domain_descriptions.items()
-        }
-        for mod in domain_mods.values():
-            mod.freeze()
-        domain_mods = cast(dict[str, DomainModule], ModuleDict(domain_mods))
-
-        coef_buffers = DictBuffer(loss_coefs)
-
-        loss_mod = DeterministicGWLosses(gw_mod, domain_mods, coef_buffers)
-
-        super().__init__(
-            gw_mod,
-            domain_mods,
-            coef_buffers,
-            loss_mod,
-            optim_lr,
-            optim_weight_decay,
-            scheduler_args,
-        )
-
-
-class VariationalGlobalWorkspace(GlobalWorkspace):
-    def __init__(
-        self,
-        domain_descriptions: Mapping[str, DomainDescription],
-        latent_dim: int,
-        loss_coefs: dict[str, torch.Tensor],
-        var_contrastive_loss: bool = False,
-        optim_lr: float = 1e-3,
-        optim_weight_decay: float = 0.0,
-        scheduler_args: SchedulerArgs | None = None,
-    ) -> None:
-        gw_mod = VariationalGWModule(domain_descriptions, latent_dim)
-
-        domain_mods = {
-            name: domain.module for name, domain in domain_descriptions.items()
-        }
-        for mod in domain_mods.values():
-            mod.freeze()
-        domain_mods = cast(dict[str, DomainModule], ModuleDict(domain_mods))
-
-        coef_buffers = DictBuffer(loss_coefs)
-
-        loss_mod = VariationalGWLosses(
-            gw_mod, domain_mods, coef_buffers, var_contrastive_loss
-        )
-
-        super().__init__(
-            gw_mod,
-            domain_mods,
-            coef_buffers,
-            loss_mod,
-            optim_lr,
-            optim_weight_decay,
-            scheduler_args,
-        )
+def global_workspace(
+    domain_descriptions: Mapping[str, DomainDescription],
+    latent_dim: int,
+    loss_coefs: dict[str, torch.Tensor],
+    optim_lr: float = 1e-3,
+    optim_weight_decay: float = 0.0,
+    scheduler_args: SchedulerArgs | None = None,
+    gw_encoders: Mapping[str, Module] | None = None,
+    gw_decoders: Mapping[str, Module] | None = None,
+) -> GlobalWorkspace:
+    gw_mod = DeterministicGWModule(
+        domain_descriptions, latent_dim, gw_encoders, gw_decoders
+    )
+
+    domain_mods = {
+        name: domain.module for name, domain in domain_descriptions.items()
+    }
+    for mod in domain_mods.values():
+        mod.freeze()
+    domain_mods = cast(dict[str, DomainModule], ModuleDict(domain_mods))
+
+    coef_buffers = DictBuffer(loss_coefs)
+
+    loss_mod = DeterministicGWLosses(gw_mod, domain_mods, coef_buffers)
+    return GlobalWorkspace(
+        gw_mod,
+        domain_mods,
+        coef_buffers,
+        loss_mod,
+        optim_lr,
+        optim_weight_decay,
+        scheduler_args,
+    )
+
+
+def variational_global_workspace(
+    domain_descriptions: Mapping[str, DomainDescription],
+    latent_dim: int,
+    loss_coefs: dict[str, torch.Tensor],
+    var_contrastive_loss: bool = False,
+    optim_lr: float = 1e-3,
+    optim_weight_decay: float = 0.0,
+    scheduler_args: SchedulerArgs | None = None,
+    gw_encoders: Mapping[str, Module] | None = None,
+    gw_decoders: Mapping[str, Module] | None = None,
+) -> GlobalWorkspace:
+    gw_mod = VariationalGWModule(
+        domain_descriptions, latent_dim, gw_encoders, gw_decoders
+    )
+
+    domain_mods = {
+        name: domain.module for name, domain in domain_descriptions.items()
+    }
+    for mod in domain_mods.values():
+        mod.freeze()
+    domain_mods = cast(dict[str, DomainModule], ModuleDict(domain_mods))
+
+    coef_buffers = DictBuffer(loss_coefs)
+
+    loss_mod = VariationalGWLosses(
+        gw_mod, domain_mods, coef_buffers, var_contrastive_loss
+    )
+    return GlobalWorkspace(
+        gw_mod,
+        domain_mods,
+        coef_buffers,
+        loss_mod,
+        optim_lr,
+        optim_weight_decay,
+        scheduler_args,
+    )
diff --git a/shimmer/modules/gw_module.py b/shimmer/modules/gw_module.py
index d7496ade..57ee60fa 100644
--- a/shimmer/modules/gw_module.py
+++ b/shimmer/modules/gw_module.py
@@ -143,11 +143,41 @@ def cycle(
         raise NotImplementedError
 
 
+def default_encoders(
+    domain_descriptions: Mapping[str, DomainDescription], latent_dim: int
+) -> dict[str, GWEncoder]:
+    return {
+        name: GWEncoder(
+            domain.latent_dim,
+            domain.encoder_hidden_dim,
+            latent_dim,
+            domain.encoder_n_layers,
+        )
+        for name, domain in domain_descriptions.items()
+    }
+
+
+def default_decoders(
+    domain_descriptions: Mapping[str, DomainDescription], latent_dim: int
+) -> dict[str, GWDecoder]:
+    return {
+        name: GWDecoder(
+            latent_dim,
+            domain.decoder_hidden_dim,
+            domain.latent_dim,
+            domain.decoder_n_layers,
+        )
+        for name, domain in domain_descriptions.items()
+    }
+
+
 class DeterministicGWModule(GWModule):
     def __init__(
         self,
         domain_descriptions: Mapping[str, DomainDescription],
         latent_dim: int,
+        encoders: Mapping[str, nn.Module] | None = None,
+        decoders: Mapping[str, nn.Module] | None = None,
     ):
         super().__init__()
 
@@ -155,40 +185,11 @@ def __init__(
         self.domain_descr = domain_descriptions
         self.latent_dim = latent_dim
 
-        self.input_dim: dict[str, int] = {}
-        self.encoder_hidden_dim: dict[str, int] = {}
-        self.encoder_n_layers: dict[str, int] = {}
-        self.decoder_hidden_dim: dict[str, int] = {}
-        self.decoder_n_layers: dict[str, int] = {}
-
-        for name, domain in domain_descriptions.items():
-            self.input_dim[name] = domain.latent_dim
-            self.encoder_hidden_dim[name] = domain.encoder_hidden_dim
-            self.encoder_n_layers[name] = domain.encoder_n_layers
-            self.decoder_hidden_dim[name] = domain.decoder_hidden_dim
-            self.decoder_n_layers[name] = domain.decoder_n_layers
-
         self.encoders = nn.ModuleDict(
-            {
-                domain: GWEncoder(
-                    self.input_dim[domain],
-                    self.encoder_hidden_dim[domain],
-                    self.latent_dim,
-                    self.encoder_n_layers[domain],
-                )
-                for domain in self.domains
-            }
+            encoders or default_encoders(domain_descriptions, latent_dim)
         )
         self.decoders = nn.ModuleDict(
-            {
-                domain: GWDecoder(
-                    self.latent_dim,
-                    self.decoder_hidden_dim[domain],
-                    self.input_dim[domain],
-                    self.decoder_n_layers[domain],
-                )
-                for domain in self.domains
-            }
+            decoders or default_decoders(domain_descriptions, latent_dim)
         )
 
     def fusion_mechanism(self, x: Mapping[str, torch.Tensor]) -> torch.Tensor:
@@ -223,11 +224,27 @@ def cycle(
         }
 
 
+def default_var_encoders(
+    domain_descriptions: Mapping[str, DomainDescription], latent_dim: int
+) -> dict[str, VariationalGWEncoder]:
+    return {
+        name: VariationalGWEncoder(
+            domain.latent_dim,
+            domain.encoder_hidden_dim,
+            latent_dim,
+            domain.encoder_n_layers,
+        )
+        for name, domain in domain_descriptions.items()
+    }
+
+
 class VariationalGWModule(GWModule):
     def __init__(
         self,
         domain_descriptions: Mapping[str, DomainDescription],
         latent_dim: int,
+        encoders: Mapping[str, nn.Module] | None = None,
+        decoders: Mapping[str, nn.Module] | None = None,
     ):
         super().__init__()
 
@@ -235,40 +252,11 @@ def __init__(
         self.domain_descr = domain_descriptions
         self.latent_dim = latent_dim
 
-        self.input_dim: dict[str, int] = {}
-        self.encoder_hidden_dim: dict[str, int] = {}
-        self.encoder_n_layers: dict[str, int] = {}
-        self.decoder_hidden_dim: dict[str, int] = {}
-        self.decoder_n_layers: dict[str, int] = {}
-
-        for name, domain in domain_descriptions.items():
-            self.input_dim[name] = domain.latent_dim
-            self.encoder_hidden_dim[name] = domain.encoder_hidden_dim
-            self.encoder_n_layers[name] = domain.encoder_n_layers
-            self.decoder_hidden_dim[name] = domain.decoder_hidden_dim
-            self.decoder_n_layers[name] = domain.decoder_n_layers
-
         self.encoders = nn.ModuleDict(
-            {
-                domain: VariationalGWEncoder(
-                    self.input_dim[domain],
-                    self.encoder_hidden_dim[domain],
-                    self.latent_dim,
-                    self.encoder_n_layers[domain],
-                )
-                for domain in self.domains
-            }
+            encoders or default_var_encoders(domain_descriptions, latent_dim)
         )
         self.decoders = nn.ModuleDict(
-            {
-                domain: GWDecoder(
-                    self.latent_dim,
-                    self.decoder_hidden_dim[domain],
-                    self.input_dim[domain],
-                    self.decoder_n_layers[domain],
-                )
-                for domain in self.domains
-            }
+            decoders or default_decoders(domain_descriptions, latent_dim)
         )
 
     def fusion_mechanism(self, x: Mapping[str, torch.Tensor]) -> torch.Tensor:
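Usage sketch for the new `gw_encoders` / `gw_decoders` hooks: the constructor arguments follow the `(in_dim, hidden_dim, out_dim, n_layers)` order used by `default_encoders` / `default_decoders` in this diff. The domain name, dimensions, and hidden sizes below are illustrative assumptions, and the final factory call is only shown schematically because building the `domain_descriptions` mapping needs pre-trained `DomainModule`s that are not part of this change.

```python
from torch import nn

from shimmer.modules.gw_module import GWDecoder, GWEncoder

# Hypothetical sizes for a single "image" domain.
domain_latent_dim = 64  # the domain module's own latent dimension
gw_latent_dim = 12      # the shared global-workspace dimension

# Custom GW encoders/decoders override the defaults built by
# default_encoders/default_decoders; any nn.Module with matching
# input/output sizes could be substituted here.
gw_encoders: dict[str, nn.Module] = {
    "image": GWEncoder(domain_latent_dim, 256, gw_latent_dim, 3),
}
gw_decoders: dict[str, nn.Module] = {
    "image": GWDecoder(gw_latent_dim, 256, domain_latent_dim, 3),
}

# These dicts are then passed to the new factory functions, e.g.:
#   global_workspace(domain_descriptions, gw_latent_dim, loss_coefs,
#                    gw_encoders=gw_encoders, gw_decoders=gw_decoders)
# where domain_descriptions maps each domain name to its DomainDescription.
```

The same two keyword arguments work with `variational_global_workspace`, where omitting `gw_encoders` falls back to `default_var_encoders` instead of `default_encoders`.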