Skip to content

Commit

Permalink
Use RandomSelection for Uncertainty module (#59)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdvllrs authored Apr 16, 2024
1 parent c6f0bfe commit 7671609
Showing 1 changed file with 65 additions and 63 deletions.
128 changes: 65 additions & 63 deletions shimmer/modules/global_workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,13 +545,10 @@ def forward( # type: ignore
)


class GlobalWorkspaceWithUncertainty(
GlobalWorkspaceBase[
GWModuleWithUncertainty, SingleDomainSelection, GWLossesWithUncertainty
]
class GlobalWorkspaceFusion(
GlobalWorkspaceBase[GWModule, RandomSelection, GWLossesFusion]
):
"""
A simple 2-domains max GlobalWorkspaceBase with uncertainty.
"""The 2-domain fusion (with broadcast loss) flavor of GlobalWorkspaceBase.
This is used to simplify a Global Workspace instantiation and only overrides the
`__init__` method.
Expand All @@ -563,7 +560,8 @@ def __init__(
gw_encoders: Mapping[str, Module],
gw_decoders: Mapping[str, Module],
workspace_dim: int,
loss_coefs: LossCoefs,
loss_coefs: BroadcastLossCoefs,
selection_temperature: float = 0.2,
optim_lr: float = 1e-3,
optim_weight_decay: float = 0.0,
scheduler_args: SchedulerArgs | None = None,
Expand All @@ -584,7 +582,9 @@ def __init__(
name to a `torch.nn.Module` class whose role is to decode a
GW representation into a unimodal latent representation.
workspace_dim (`int`): dimension of the GW.
loss_coefs (`LossCoefs`): loss coefficients
loss_coefs (`BroadcastLossCoefs`): loss coefs for the losses.
selection_temperature (`float`): temperature value for the RandomSelection
module.
optim_lr (`float`): learning rate
optim_weight_decay (`float`): weight decay
scheduler_args (`SchedulerArgs | None`): optimization scheduler's arguments
Expand All @@ -595,23 +595,16 @@ def __init__(
contrastive losses.
"""
domain_mods = freeze_domain_modules(domain_mods)
gw_mod = GWModule(domain_mods, workspace_dim, gw_encoders, gw_decoders)

gw_mod = GWModuleWithUncertainty(
domain_mods, workspace_dim, gw_encoders, gw_decoders
)

selection_mod = SingleDomainSelection()

contrastive_loss = ContrastiveLoss(
torch.tensor([1]).log(), "mean", learn_logit_scale
)
if contrastive_loss is None:
contrastive_loss = ContrastiveLoss(
torch.tensor([1 / 0.07]).log(), "mean", learn_logit_scale
)

loss_mod = GWLossesWithUncertainty(
gw_mod,
selection_mod,
domain_mods,
loss_coefs,
contrastive_loss,
selection_mod = RandomSelection(selection_temperature)
loss_mod = GWLossesFusion(
gw_mod, selection_mod, domain_mods, loss_coefs, contrastive_loss
)

super().__init__(
Expand All @@ -623,37 +616,14 @@ def __init__(
scheduler_args,
)

def forward( # type: ignore
self,
latent_domains: LatentsDomainGroupsT,
) -> GWPredictions:
"""
Computes demi-cycles, cycles, and translations.
Args:
latent_domains (`LatentsT`): Groups of domains for the computation.
Returns:
`GWPredictions`: the predictions on the batch.
"""
return GWPredictions(
demi_cycles=batch_demi_cycles(
self.gw_mod, self.selection_mod, latent_domains
),
cycles=batch_cycles(
self.gw_mod, self.selection_mod, latent_domains, self.domain_mods.keys()
),
translations=batch_translations(
self.gw_mod, self.selection_mod, latent_domains
),
**super().forward(latent_domains),
)


class GlobalWorkspaceFusion(
GlobalWorkspaceBase[GWModule, RandomSelection, GWLossesFusion]
class GlobalWorkspaceWithUncertainty(
GlobalWorkspaceBase[
GWModuleWithUncertainty, RandomSelection, GWLossesWithUncertainty
]
):
"""The 2-domain fusion (with broadcast loss) flavor of GlobalWorkspaceBase.
"""
A simple 2-domains max GlobalWorkspaceBase with uncertainty.
This is used to simplify a Global Workspace instantiation and only overrides the
`__init__` method.
Expand All @@ -665,7 +635,7 @@ def __init__(
gw_encoders: Mapping[str, Module],
gw_decoders: Mapping[str, Module],
workspace_dim: int,
loss_coefs: BroadcastLossCoefs,
loss_coefs: LossCoefs,
selection_temperature: float = 0.2,
optim_lr: float = 1e-3,
optim_weight_decay: float = 0.0,
Expand All @@ -687,9 +657,8 @@ def __init__(
name to a `torch.nn.Module` class whose role is to decode a
GW representation into a unimodal latent representation.
workspace_dim (`int`): dimension of the GW.
loss_coefs (`BroadcastLossCoefs`): loss coefs for the losses.
selection_temperature (`float`): temperature value for the RandomSelection
module.
loss_coefs (`LossCoefs`): loss coefficients
selection_temperature (`float`): temperature for `RandomSelection`
optim_lr (`float`): learning rate
optim_weight_decay (`float`): weight decay
scheduler_args (`SchedulerArgs | None`): optimization scheduler's arguments
Expand All @@ -700,16 +669,23 @@ def __init__(
contrastive losses.
"""
domain_mods = freeze_domain_modules(domain_mods)
gw_mod = GWModule(domain_mods, workspace_dim, gw_encoders, gw_decoders)

if contrastive_loss is None:
contrastive_loss = ContrastiveLoss(
torch.tensor([1 / 0.07]).log(), "mean", learn_logit_scale
)
gw_mod = GWModuleWithUncertainty(
domain_mods, workspace_dim, gw_encoders, gw_decoders
)

selection_mod = RandomSelection(selection_temperature)
loss_mod = GWLossesFusion(
gw_mod, selection_mod, domain_mods, loss_coefs, contrastive_loss

contrastive_loss = ContrastiveLoss(
torch.tensor([1]).log(), "mean", learn_logit_scale
)

loss_mod = GWLossesWithUncertainty(
gw_mod,
selection_mod,
domain_mods,
loss_coefs,
contrastive_loss,
)

super().__init__(
Expand All @@ -721,6 +697,32 @@ def __init__(
scheduler_args,
)

def forward( # type: ignore
self,
latent_domains: LatentsDomainGroupsT,
) -> GWPredictions:
"""
Computes demi-cycles, cycles, and translations.
Args:
latent_domains (`LatentsT`): Groups of domains for the computation.
Returns:
`GWPredictions`: the predictions on the batch.
"""
return GWPredictions(
demi_cycles=batch_demi_cycles(
self.gw_mod, self.selection_mod, latent_domains
),
cycles=batch_cycles(
self.gw_mod, self.selection_mod, latent_domains, self.domain_mods.keys()
),
translations=batch_translations(
self.gw_mod, self.selection_mod, latent_domains
),
**super().forward(latent_domains),
)


def pretrained_global_workspace(
checkpoint_path: str | Path,
Expand Down

0 comments on commit 7671609

Please sign in to comment.