diff --git a/shimmer/modules/global_workspace.py b/shimmer/modules/global_workspace.py
index 7aec98ff..852bc1cc 100644
--- a/shimmer/modules/global_workspace.py
+++ b/shimmer/modules/global_workspace.py
@@ -684,6 +684,7 @@ def __init__(
         learn_logit_scale: bool = False,
         use_normalized_constrastive: bool = True,
         contrastive_loss: ContrastiveLossType | None = None,
+        precision_softmax_temp: float = 0.01,
     ) -> None:
         """
         Initializes a Global Workspace
@@ -712,6 +713,8 @@ def __init__(
             contrastive_loss (`ContrastiveLossType | None`): a contrastive loss
                 function used for alignment. `learn_logit_scale` will not
                 affect custom contrastive losses.
+            precision_softmax_temp (`float`): temperature used in the softmax
+                over the precisions
         """
         domain_mods = freeze_domain_modules(domain_mods)
 
@@ -722,6 +725,7 @@ def __init__(
             gw_decoders,
             sensitivity_selection,
             sensitivity_precision,
+            precision_softmax_temp,
         )
 
         selection_mod = FixedSharedSelection()
diff --git a/shimmer/modules/gw_module.py b/shimmer/modules/gw_module.py
index 9c1922eb..3dd1b600 100644
--- a/shimmer/modules/gw_module.py
+++ b/shimmer/modules/gw_module.py
@@ -338,6 +338,7 @@ def __init__(
         gw_decoders: Mapping[str, nn.Module],
         sensitivity_selection: float = 1,
         sensitivity_precision: float = 1,
+        precision_softmax_temp: float = 0.01,
     ) -> None:
         """
         Initializes the GWModuleBayesian.
@@ -353,6 +354,8 @@ def __init__(
                 GW representation to a unimodal latent representation.
             sensitivity_selection (`float`): sensivity coef $c'_1$
             sensitivity_precision (`float`): sensitivity coef $c'_2$
+            precision_softmax_temp (`float`): temperature used in the softmax
+                over the precisions
         """
         super().__init__(domain_modules, workspace_dim, gw_encoders, gw_decoders)
 
@@ -366,6 +369,7 @@ def __init__(
 
         self.sensitivity_selection = sensitivity_selection
         self.sensitivity_precision = sensitivity_precision
+        self.precision_softmax_temp = precision_softmax_temp
 
     def get_precision(self, domain: str, x: torch.Tensor) -> torch.Tensor:
         """
@@ -437,7 +441,9 @@ def fuse(
             domains.append(x[domain])
         combined_scores = compute_fusion_scores(
             torch.stack(scores).unsqueeze(-1),
-            torch.softmax(torch.stack(precisions), dim=0),
+            torch.softmax(
+                torch.tanh(torch.stack(precisions)) * self.precision_softmax_temp, dim=0
+            ),
             self.sensitivity_selection,
             self.sensitivity_precision,
         )
diff --git a/shimmer/modules/losses.py b/shimmer/modules/losses.py
index 7ba5dbf5..08273a93 100644
--- a/shimmer/modules/losses.py
+++ b/shimmer/modules/losses.py
@@ -312,12 +312,14 @@ def contrastive_loss_bayesian(
             loss_name = f"contrastive_{domain1_name}_and_{domain2_name}"
             z2 = gw_mod.encode({domain2_name: domain2})[domain2_name]
             z2_precision = gw_mod.get_precision(domain2_name, domain2)
-            coef = torch.stack([z1_precision, z2_precision]).softmax(dim=0)
-            loss_output = contrastive_fn(
-                z1 * coef[0] * coef[1], z2 * coef[0] * coef[1]
+            coef = torch.softmax(
+                gw_mod.precision_softmax_temp
+                * torch.stack([z1_precision, z2_precision]),
+                dim=0,
             )
+            norm = torch.sqrt(coef[0] * coef[1])
+            loss_output = contrastive_fn(z1 * norm, z2 * norm)
             loss_output_no_norm = contrastive_fn(z1, z2)
-
             losses[loss_name] = loss_output.loss
             metrics.update(
                 {f"{loss_name}_{k}": v for k, v in loss_output.metrics.items()}
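
For reviewers, a minimal standalone sketch of what the new fusion weighting in `GWModuleBayesian.fuse` does (plain PyTorch only; the shapes and toy precision values are made up for illustration, not taken from shimmer). The precisions are first squashed with `tanh` and scaled by `precision_softmax_temp` before the softmax, so extreme precision estimates can no longer saturate the fusion weights; note that the temperature here multiplies the logits, so smaller values flatten the weights toward uniform.

```python
import torch

precision_softmax_temp = 0.01  # new constructor default

# Two domains, batch of 1, workspace_dim of 4; one domain claims much
# higher precision than the other. (Toy values for illustration.)
precisions = torch.stack(
    [
        torch.full((1, 4), 5.0),   # confident domain
        torch.full((1, 4), -5.0),  # unconfident domain
    ]
)

# Old behaviour: raw precisions go straight into the softmax, so large
# logits saturate the weights toward one domain.
old_weights = torch.softmax(precisions, dim=0)

# New behaviour: tanh bounds the logits to (-1, 1) and the small temperature
# flattens them further, keeping the fusion weights close to uniform.
new_weights = torch.softmax(torch.tanh(precisions) * precision_softmax_temp, dim=0)

print(old_weights[:, 0, 0])  # ~[1.0000, 0.0000]
print(new_weights[:, 0, 0])  # ~[0.5050, 0.4950]
```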
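The `losses.py` change replaces the per-side factor `coef[0] * coef[1]` with its square root. A quick sanity check, again a toy sketch with random tensors rather than shimmer code, shows why: scaling each embedding by `sqrt(coef[0] * coef[1])` scales the pairwise products by `coef[0] * coef[1]` exactly once, the same total weighting as applying `coef[0]` to `z1` and `coef[1]` to `z2`, whereas the old code applied the product to both sides and thus squared it.

```python
import torch

precision_softmax_temp = 0.01

# Toy embeddings and precision estimates: batch of 8, workspace_dim of 12.
z1, z2 = torch.randn(8, 12), torch.randn(8, 12)
z1_precision, z2_precision = torch.randn(8, 12), torch.randn(8, 12)

# Temperature-scaled softmax over the two precision maps, as in the patch
# (no tanh in the loss path, matching the diff).
coef = torch.softmax(
    precision_softmax_temp * torch.stack([z1_precision, z2_precision]),
    dim=0,
)
norm = torch.sqrt(coef[0] * coef[1])

# Products of the symmetrically normalised embeddings equal the raw
# products weighted once by coef[0] * coef[1].
assert torch.allclose(z1 * norm * (z2 * norm), z1 * coef[0] * (z2 * coef[1]))
```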