From c0eb83b5c01c95aa62a83b7a600c0a5cc4409575 Mon Sep 17 00:00:00 2001 From: bdvllrs Date: Tue, 21 May 2024 15:18:09 +0200 Subject: [PATCH] Add option to switch off bayesian cont loss in bayesian GW (#71) --- shimmer/modules/global_workspace.py | 4 ++++ shimmer/modules/losses.py | 16 +++++++++++++--- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/shimmer/modules/global_workspace.py b/shimmer/modules/global_workspace.py index cbf3d864..7aec98ff 100644 --- a/shimmer/modules/global_workspace.py +++ b/shimmer/modules/global_workspace.py @@ -682,6 +682,7 @@ def __init__( optim_weight_decay: float = 0.0, scheduler_args: SchedulerArgs | None = None, learn_logit_scale: bool = False, + use_normalized_constrastive: bool = True, contrastive_loss: ContrastiveLossType | None = None, ) -> None: """ @@ -706,6 +707,8 @@ def __init__( scheduler_args (`SchedulerArgs | None`): optimization scheduler's arguments learn_logit_scale (`bool`): whether to learn the contrastive learning contrastive loss when using the default contrastive loss. + use_normalized_constrastive (`bool`): whether to normalize the contrastive + loss by the precision coefficients contrastive_loss (`ContrastiveLossType | None`): a contrastive loss function used for alignment. `learn_logit_scale` will not affect custom contrastive losses. 
@@ -733,6 +736,7 @@ def __init__( domain_mods, loss_coefs, contrastive_loss, + use_normalized_constrastive, ) super().__init__( diff --git a/shimmer/modules/losses.py b/shimmer/modules/losses.py index ee6240c1..7ba5dbf5 100644 --- a/shimmer/modules/losses.py +++ b/shimmer/modules/losses.py @@ -316,10 +316,13 @@ def contrastive_loss_bayesian( loss_output = contrastive_fn( z1 * coef[0] * coef[1], z2 * coef[0] * coef[1] ) + loss_output_no_norm = contrastive_fn(z1, z2) + losses[loss_name] = loss_output.loss metrics.update( {f"{loss_name}_{k}": v for k, v in loss_output.metrics.items()} ) + metrics[f"unnorm_{loss_name}"] = loss_output_no_norm.loss losses["contrastives"] = torch.stack(list(losses.values()), dim=0).mean() losses.update(metrics) @@ -763,6 +766,7 @@ def __init__( domain_mods: dict[str, DomainModule], loss_coefs: BroadcastLossCoefs, contrastive_fn: ContrastiveLossType, + use_normalized_constrastive: bool = True, ): """ Loss module with uncertainty prediction to use with the GlobalWorkspaceBayesian @@ -775,6 +779,8 @@ def __init__( loss_coefs (`BroadcastLossCoefs`): loss coefficients contrastive_fn (`ContrastiveLossType`): the contrastive function to use in contrastive loss + use_normalized_constrastive (`bool`): whether to normalize the contrastive + loss by the precision coefficients """ super().__init__() @@ -795,6 +801,8 @@ def __init__( Contrastive loss to use. """ + self.use_normalized_constrastive = use_normalized_constrastive + def contrastive_loss( self, latent_domains: LatentsDomainGroupsT ) -> dict[str, torch.Tensor]: @@ -807,9 +815,11 @@ def contrastive_loss( Returns: `dict[str, torch.Tensor]`: a dict of metrics. 
""" - return contrastive_loss_bayesian( - self.gw_mod, latent_domains, self.contrastive_fn - ) + if self.use_normalized_constrastive: + return contrastive_loss_bayesian( + self.gw_mod, latent_domains, self.contrastive_fn + ) + return contrastive_loss(self.gw_mod, latent_domains, self.contrastive_fn) def broadcast_loss( self, latent_domains: LatentsDomainGroupsT