Add option to switch off bayesian cont loss in bayesian GW (#71)
bdvllrs authored May 21, 2024
1 parent d2b67c0 commit c0eb83b
Showing 2 changed files with 17 additions and 3 deletions.
4 changes: 4 additions & 0 deletions shimmer/modules/global_workspace.py
@@ -682,6 +682,7 @@ def __init__(
         optim_weight_decay: float = 0.0,
         scheduler_args: SchedulerArgs | None = None,
         learn_logit_scale: bool = False,
+        use_normalized_constrastive: bool = True,
         contrastive_loss: ContrastiveLossType | None = None,
     ) -> None:
         """
@@ -706,6 +707,8 @@ def __init__(
             scheduler_args (`SchedulerArgs | None`): optimization scheduler's arguments
             learn_logit_scale (`bool`): whether to learn the contrastive learning
                 contrastive loss when using the default contrastive loss.
+            use_normalized_constrastive (`bool`): whether to use the normalized cont
+                loss by the precision coefs
             contrastive_loss (`ContrastiveLossType | None`): a contrastive loss
                 function used for alignment. `learn_logit_scale` will not affect custom
                 contrastive losses.
@@ -733,6 +736,7 @@ def __init__(
             domain_mods,
             loss_coefs,
             contrastive_loss,
+            use_normalized_constrastive,
         )

         super().__init__(
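
For orientation, a hypothetical usage sketch of the new flag. Only use_normalized_constrastive comes from this diff (its spelling is taken verbatim from the commit); every other argument is a placeholder, since the full GlobalWorkspaceBayesian signature is not visible in these hunks. The flag defaults to True, so existing configurations keep the precision-weighted behaviour.

# Hypothetical sketch; only use_normalized_constrastive is from this commit,
# the remaining arguments stand in for parts of the signature not shown above.
from shimmer.modules.global_workspace import GlobalWorkspaceBayesian

global_workspace = GlobalWorkspaceBayesian(
    domain_mods=domain_mods,  # dict[str, DomainModule], defined elsewhere
    gw_encoders=gw_encoders,  # assumed name: per-domain encoders
    gw_decoders=gw_decoders,  # assumed name: per-domain decoders
    workspace_dim=12,
    loss_coefs=loss_coefs,    # BroadcastLossCoefs, defined elsewhere
    use_normalized_constrastive=False,  # new: disable precision-weighted contrastive loss
)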
16 changes: 13 additions & 3 deletions shimmer/modules/losses.py
@@ -316,10 +316,13 @@ def contrastive_loss_bayesian(
             loss_output = contrastive_fn(
                 z1 * coef[0] * coef[1], z2 * coef[0] * coef[1]
             )
+            loss_output_no_norm = contrastive_fn(z1, z2)
+
             losses[loss_name] = loss_output.loss
             metrics.update(
                 {f"{loss_name}_{k}": v for k, v in loss_output.metrics.items()}
             )
+            metrics[f"unnorm_{loss_name}"] = loss_output_no_norm.loss

     losses["contrastives"] = torch.stack(list(losses.values()), dim=0).mean()
     losses.update(metrics)
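
To make the new unnorm_ metric concrete, a self-contained toy (all names, shapes, and values here are invented; in shimmer, contrastive_fn returns a LossOutput whose .loss is logged, and the coefficients are precision estimates from the Bayesian GW module, not random numbers):

# Toy illustration, not shimmer code.
import torch
import torch.nn.functional as F

def toy_contrastive_fn(z1: torch.Tensor, z2: torch.Tensor) -> torch.Tensor:
    # InfoNCE-style loss over a batch of paired latents.
    logits = F.normalize(z1, dim=-1) @ F.normalize(z2, dim=-1).T
    return F.cross_entropy(logits, torch.arange(z1.size(0)))

z1, z2 = torch.randn(8, 12), torch.randn(8, 12)
coef = [torch.rand(8, 12), torch.rand(8, 12)]  # stand-in precision coefficients

normalized = toy_contrastive_fn(z1 * coef[0] * coef[1], z2 * coef[0] * coef[1])
unnormalized = toy_contrastive_fn(z1, z2)  # what metrics[f"unnorm_{loss_name}"] records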
@@ -763,6 +766,7 @@ def __init__(
         domain_mods: dict[str, DomainModule],
         loss_coefs: BroadcastLossCoefs,
         contrastive_fn: ContrastiveLossType,
+        use_normalized_constrastive: bool = True,
     ):
         """
         Loss module with uncertainty prediction to use with the GlobalWorkspaceBayesian
@@ -775,6 +779,8 @@ def __init__(
             loss_coefs (`BroadcastLossCoefs`): loss coefficients
             contrastive_fn (`ContrastiveLossType`): the contrastive function
                 to use in contrastive loss
+            use_normalized_constrastive (`bool`): whether to use the normalized cont
+                loss by the precision coefs
         """
         super().__init__()

@@ -795,6 +801,8 @@ def __init__(
         Contrastive loss to use.
         """

+        self.use_normalized_constrastive = use_normalized_constrastive
+
     def contrastive_loss(
         self, latent_domains: LatentsDomainGroupsT
     ) -> dict[str, torch.Tensor]:
@@ -807,9 +815,11 @@ def contrastive_loss(
         Returns:
             `dict[str, torch.Tensor]`: a dict of metrics.
         """
-        return contrastive_loss_bayesian(
-            self.gw_mod, latent_domains, self.contrastive_fn
-        )
+        if self.use_normalized_constrastive:
+            return contrastive_loss_bayesian(
+                self.gw_mod, latent_domains, self.contrastive_fn
+            )
+        return contrastive_loss(self.gw_mod, latent_domains, self.contrastive_fn)

     def broadcast_loss(
         self, latent_domains: LatentsDomainGroupsT
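
The behavioural change reduces to a dispatch on the stored flag: the precision-weighted contrastive_loss_bayesian when True, the plain contrastive_loss when False. A runnable toy of that pattern (class and loss functions are invented stand-ins, not shimmer code; the real methods operate on gw_mod and latent_domains and return dicts of tensors):

# Toy dispatch sketch, not shimmer code.
import torch

def toy_loss_bayesian(z1, z2, coef):
    # stand-in for contrastive_loss_bayesian: precision-scaled inputs
    return ((z1 * coef - z2 * coef) ** 2).mean()

def toy_loss_plain(z1, z2):
    # stand-in for contrastive_loss: raw inputs
    return ((z1 - z2) ** 2).mean()

class ToyBayesianLosses:
    def __init__(self, use_normalized_constrastive: bool = True):
        self.use_normalized_constrastive = use_normalized_constrastive

    def contrastive_loss(self, z1, z2, coef):
        if self.use_normalized_constrastive:
            return toy_loss_bayesian(z1, z2, coef)
        return toy_loss_plain(z1, z2)

z1, z2, coef = torch.randn(4, 8), torch.randn(4, 8), torch.rand(4, 8)
assert ToyBayesianLosses(False).contrastive_loss(z1, z2, coef) == toy_loss_plain(z1, z2)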
