Skip to content

Commit

Permalink
Refactor GlobalWorkspaceWithUncertainty and how to init the GWLosses mod
Browse files Browse the repository at this point in the history
  • Loading branch information
bdvllrs committed Mar 8, 2024
1 parent cf239ef commit 0cdd417
Showing 1 changed file with 19 additions and 17 deletions.
36 changes: 19 additions & 17 deletions shimmer/modules/global_workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,26 +596,28 @@ def __init__(
domain_mods, workspace_dim, gw_encoders, gw_decoders
)

if use_cont_loss_with_uncertainty:
if cont_loss_with_uncertainty is None:
cont_loss_with_uncertainty = ContrastiveLossWithUncertainty(
torch.tensor([1]).log(), "mean", learn_logit_scale
)
loss_mod = GWLossesWithUncertainty(
gw_mod,
domain_mods,
loss_coefs,
cont_fn_with_uncertainty=cont_loss_with_uncertainty,
if use_cont_loss_with_uncertainty and cont_loss_with_uncertainty is None:
cont_loss_with_uncertainty = ContrastiveLossWithUncertainty(
torch.tensor([1]).log(), "mean", learn_logit_scale
)
else:
if contrastive_loss is None:
contrastive_loss = ContrastiveLoss(
torch.tensor([1]).log(), "mean", learn_logit_scale
)
loss_mod = GWLossesWithUncertainty(
gw_mod, domain_mods, loss_coefs, contrastive_fn=contrastive_loss
elif not use_cont_loss_with_uncertainty and contrastive_loss is None:
contrastive_loss = ContrastiveLoss(
torch.tensor([1]).log(), "mean", learn_logit_scale
)

if use_cont_loss_with_uncertainty:
contrastive_loss = None
else:
cont_loss_with_uncertainty = None

loss_mod = GWLossesWithUncertainty(
gw_mod,
domain_mods,
loss_coefs,
contrastive_loss,
cont_loss_with_uncertainty,
)

super().__init__(gw_mod, loss_mod, optim_lr, optim_weight_decay, scheduler_args)


Expand Down

0 comments on commit 0cdd417

Please sign in to comment.