Skip to content

Commit

Permalink
Simplify API for WithUncertainty modules. (#45)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdvllrs authored Apr 4, 2024
1 parent b5cdbee commit c2e69a8
Show file tree
Hide file tree
Showing 9 changed files with 160 additions and 413 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,5 @@ refers to `DeterministicGlobalWorkspace`.
methods in `GlobalWorkspaceBase`.
* Remove on_before_gw_encode_{loss} callbacks to allow sharing computation between
loss functions.
* Remove many _with_uncertainty functions. The GWModuleWithUncertainty now behaves like
the other GWModules.
26 changes: 8 additions & 18 deletions shimmer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,7 @@
from shimmer.modules.contrastive_loss import (
ContrastiveLoss,
ContrastiveLossType,
ContrastiveLossWithUncertainty,
ContrastiveLossWithUncertaintyType,
contrastive_loss,
contrastive_loss_with_uncertainty,
)
from shimmer.modules.domain import DomainModule, LossOutput
from shimmer.modules.global_workspace import (
Expand All @@ -20,7 +17,6 @@
GWDecoder,
GWEncoder,
GWEncoderLinear,
GWEncoderWithUncertainty,
GWModule,
GWModuleBase,
GWModuleWithUncertainty,
Expand All @@ -31,17 +27,17 @@
GWLossesWithUncertainty,
LossCoefs,
)
from shimmer.modules.selection import (
RandomSelection,
SelectionBase,
SingleDomainSelection,
)
from shimmer.modules.utils import (
batch_cycles,
batch_cycles_with_uncertainty,
batch_demi_cycles,
batch_demi_cycles_with_uncertainty,
batch_translations,
batch_translations_with_uncertainty,
cycle,
cycle_with_uncertainty,
translation,
translation_with_uncertainty,
)
from shimmer.types import (
LatentsDomainGroupDT,
Expand Down Expand Up @@ -79,16 +75,12 @@
"GWDecoder",
"GWEncoder",
"GWEncoderLinear",
"GWEncoderWithUncertainty",
"GWModuleBase",
"GWModule",
"GWModuleWithUncertainty",
"ContrastiveLossType",
"ContrastiveLossWithUncertaintyType",
"contrastive_loss",
"ContrastiveLoss",
"contrastive_loss_with_uncertainty",
"ContrastiveLossWithUncertainty",
"LossCoefs",
"GWLossesBase",
"GWLosses",
Expand All @@ -99,12 +91,10 @@
"batch_translations",
"cycle",
"translation",
"cycle_with_uncertainty",
"translation_with_uncertainty",
"batch_translations_with_uncertainty",
"batch_demi_cycles_with_uncertainty",
"batch_cycles_with_uncertainty",
"MIGRATION_DIR",
"migrate_model",
"SaveMigrations",
"RandomSelection",
"SelectionBase",
"SingleDomainSelection",
]
20 changes: 8 additions & 12 deletions shimmer/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
GWDecoder,
GWEncoder,
GWEncoderLinear,
GWEncoderWithUncertainty,
GWModule,
GWModuleBase,
GWModuleWithUncertainty,
Expand All @@ -31,17 +30,17 @@
GWLossesWithUncertainty,
LossCoefs,
)
from shimmer.modules.selection import (
RandomSelection,
SelectionBase,
SingleDomainSelection,
)
from shimmer.modules.utils import (
batch_cycles,
batch_cycles_with_uncertainty,
batch_demi_cycles,
batch_demi_cycles_with_uncertainty,
batch_translations,
batch_translations_with_uncertainty,
cycle,
cycle_with_uncertainty,
translation,
translation_with_uncertainty,
)
from shimmer.modules.vae import (
VAE,
Expand All @@ -64,7 +63,6 @@
"GWDecoder",
"GWEncoder",
"GWEncoderLinear",
"GWEncoderWithUncertainty",
"GWModuleBase",
"GWModule",
"GWModuleWithUncertainty",
Expand All @@ -90,9 +88,7 @@
"batch_translations",
"cycle",
"translation",
"cycle_with_uncertainty",
"translation_with_uncertainty",
"batch_translations_with_uncertainty",
"batch_demi_cycles_with_uncertainty",
"batch_cycles_with_uncertainty",
"RandomSelection",
"SelectionBase",
"SingleDomainSelection",
]
47 changes: 8 additions & 39 deletions shimmer/modules/global_workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,7 @@
from torch.nn import Module, ModuleDict
from torch.optim.lr_scheduler import OneCycleLR

from shimmer.modules.contrastive_loss import (
ContrastiveLoss,
ContrastiveLossType,
ContrastiveLossWithUncertainty,
ContrastiveLossWithUncertaintyType,
)
from shimmer.modules.contrastive_loss import ContrastiveLoss, ContrastiveLossType
from shimmer.modules.domain import DomainModule
from shimmer.modules.gw_module import (
GWModule,
Expand All @@ -29,14 +24,7 @@
LossCoefs,
)
from shimmer.modules.selection import SelectionBase, SingleDomainSelection
from shimmer.modules.utils import (
batch_cycles,
batch_cycles_with_uncertainty,
batch_demi_cycles,
batch_demi_cycles_with_uncertainty,
batch_translations,
batch_translations_with_uncertainty,
)
from shimmer.modules.utils import batch_cycles, batch_demi_cycles, batch_translations
from shimmer.types import (
LatentsDomainGroupsDT,
LatentsDomainGroupsT,
Expand Down Expand Up @@ -564,13 +552,11 @@ def __init__(
gw_decoders: Mapping[str, Module],
workspace_dim: int,
loss_coefs: LossCoefs,
use_cont_loss_with_uncertainty: bool = False,
optim_lr: float = 1e-3,
optim_weight_decay: float = 0.0,
scheduler_args: SchedulerArgs | None = None,
learn_logit_scale: bool = False,
contrastive_loss: ContrastiveLossType | None = None,
cont_loss_with_uncertainty: ContrastiveLossWithUncertaintyType | None = None,
) -> None:
"""
Initializes a Global Workspace
Expand All @@ -587,9 +573,6 @@ def __init__(
GW representation into a unimodal latent representations.
workspace_dim (`int`): dimension of the GW.
loss_coefs (`LossCoefs`): loss coefficients
use_cont_loss_with_uncertainty (`bool`): whether to use the contrastive
loss with uncertainty which uses means and log variance for
computations.
optim_lr (`float`): learning rate
optim_weight_decay (`float`): weight decay
scheduler_args (`SchedulerArgs | None`): optimization scheduler's arguments
Expand All @@ -598,9 +581,6 @@ def __init__(
contrastive_loss (`ContrastiveLossType | None`): a contrastive loss
function used for alignment. `learn_logit_scale` will not affect custom
contrastive losses.
cont_loss_with_uncertainty (`ContrastiveLossWithUncertaintyType | None`): a
contrastive loss with uncertainty.
Only used if `use_cont_loss_with_uncertainty` is set to `True`.
"""
domain_mods = freeze_domain_modules(domain_mods)

Expand All @@ -610,27 +590,16 @@ def __init__(

selection_mod = SingleDomainSelection()

if use_cont_loss_with_uncertainty and cont_loss_with_uncertainty is None:
cont_loss_with_uncertainty = ContrastiveLossWithUncertainty(
torch.tensor([1]).log(), "mean", learn_logit_scale
)
elif not use_cont_loss_with_uncertainty and contrastive_loss is None:
contrastive_loss = ContrastiveLoss(
torch.tensor([1]).log(), "mean", learn_logit_scale
)

if use_cont_loss_with_uncertainty:
contrastive_loss = None
else:
cont_loss_with_uncertainty = None
contrastive_loss = ContrastiveLoss(
torch.tensor([1]).log(), "mean", learn_logit_scale
)

loss_mod = GWLossesWithUncertainty(
gw_mod,
selection_mod,
domain_mods,
loss_coefs,
contrastive_loss,
cont_loss_with_uncertainty,
)

super().__init__(
Expand All @@ -656,13 +625,13 @@ def forward( # type: ignore
`GWPredictions`: the predictions on the batch.
"""
return GWPredictions(
demi_cycles=batch_demi_cycles_with_uncertainty(
demi_cycles=batch_demi_cycles(
self.gw_mod, self.selection_mod, latent_domains
),
cycles=batch_cycles_with_uncertainty(
cycles=batch_cycles(
self.gw_mod, self.selection_mod, latent_domains, self.domain_mods.keys()
),
translations=batch_translations_with_uncertainty(
translations=batch_translations(
self.gw_mod, self.selection_mod, latent_domains
),
**super().forward(latent_domains),
Expand Down
Loading

0 comments on commit c2e69a8

Please sign in to comment.