Commit cd525cb
Add temperature params to improve Bayesian training (#72)
bdvllrs authored May 21, 2024
1 parent c0eb83b commit cd525cb
Showing 3 changed files with 17 additions and 5 deletions.
shimmer/modules/global_workspace.py (4 additions & 0 deletions)
@@ -684,6 +684,7 @@ def __init__(
         learn_logit_scale: bool = False,
         use_normalized_constrastive: bool = True,
         contrastive_loss: ContrastiveLossType | None = None,
+        precision_softmax_temp: float = 0.01,
     ) -> None:
         """
         Initializes a Global Workspace
@@ -712,6 +713,8 @@ def __init__(
             contrastive_loss (`ContrastiveLossType | None`): a contrastive loss
                 function used for alignment. `learn_logit_scale` will not affect custom
                 contrastive losses.
+            precision_softmax_temp (`float`): temperature to use in softmax of
+                precision
         """
         domain_mods = freeze_domain_modules(domain_mods)

@@ -722,6 +725,7 @@ def __init__(
             gw_decoders,
             sensitivity_selection,
             sensitivity_precision,
+            precision_softmax_temp,
         )

         selection_mod = FixedSharedSelection()
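For intuition, here is a minimal standalone sketch (not repository code; the tensor values are invented) of what the new parameter does: multiplying the precision logits by a small precision_softmax_temp, such as the 0.01 default, flattens the softmax, so fusion weights start near uniform rather than saturating on whichever domain reports the largest precision.

import torch

# Invented precision logits for two domains (rows) over two samples (columns).
precisions = torch.tensor([[3.0, -1.0],
                           [0.5,  2.0]])

for temp in (1.0, 0.01):
    weights = torch.softmax(precisions * temp, dim=0)
    print(f"temp={temp}: {weights.tolist()}")

# temp=1.0 gives peaked weights (about 0.92/0.08 and 0.05/0.95 per column);
# temp=0.01 keeps both columns within about 0.01 of uniform.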
shimmer/modules/gw_module.py (7 additions & 1 deletion)
@@ -338,6 +338,7 @@ def __init__(
         gw_decoders: Mapping[str, nn.Module],
         sensitivity_selection: float = 1,
         sensitivity_precision: float = 1,
+        precision_softmax_temp: float = 0.01,
     ) -> None:
         """
         Initializes the GWModuleBayesian.
@@ -353,6 +354,8 @@ def __init__(
                 GW representation to a unimodal latent representation.
             sensitivity_selection (`float`): sensitivity coef $c'_1$
             sensitivity_precision (`float`): sensitivity coef $c'_2$
+            precision_softmax_temp (`float`): temperature to use in softmax of
+                precision
         """
         super().__init__(domain_modules, workspace_dim, gw_encoders, gw_decoders)

@@ -366,6 +369,7 @@ def __init__(

         self.sensitivity_selection = sensitivity_selection
         self.sensitivity_precision = sensitivity_precision
+        self.precision_softmax_temp = precision_softmax_temp

     def get_precision(self, domain: str, x: torch.Tensor) -> torch.Tensor:
         """
@@ -437,7 +441,9 @@ def fuse(
             domains.append(x[domain])
         combined_scores = compute_fusion_scores(
             torch.stack(scores).unsqueeze(-1),
-            torch.softmax(torch.stack(precisions), dim=0),
+            torch.softmax(
+                torch.tanh(torch.stack(precisions)) * self.precision_softmax_temp, dim=0
+            ),
             self.sensitivity_selection,
             self.sensitivity_precision,
         )
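To see the effect of the tanh bound added in fuse() above, here is a standalone sketch (illustrative names and invented values, not repository code): tanh clamps each raw precision into (-1, 1) before the temperature scaling, so the softmax logits stay inside (-precision_softmax_temp, precision_softmax_temp) however large the learned precisions grow.

import torch

def fusion_weights(precisions: torch.Tensor, temp: float = 0.01) -> torch.Tensor:
    # Same expression as the updated fuse(): tanh keeps every logit in (-1, 1),
    # and temp shrinks the spread to (-temp, temp), so the weights over domains
    # (dim 0) can only drift slowly away from uniform during training.
    return torch.softmax(torch.tanh(precisions) * temp, dim=0)

raw = torch.tensor([[100.0, -5.0],
                    [  0.0,  5.0]])   # invented raw precisions, one extreme

print(fusion_weights(raw))            # near 0.5/0.5 in both columns
print(torch.softmax(raw, dim=0))      # pre-#72 behavior: saturates to ~one-hot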
shimmer/modules/losses.py (6 additions & 4 deletions)
@@ -312,12 +312,14 @@ def contrastive_loss_bayesian(
     loss_name = f"contrastive_{domain1_name}_and_{domain2_name}"
     z2 = gw_mod.encode({domain2_name: domain2})[domain2_name]
     z2_precision = gw_mod.get_precision(domain2_name, domain2)
-    coef = torch.stack([z1_precision, z2_precision]).softmax(dim=0)
-    loss_output = contrastive_fn(
-        z1 * coef[0] * coef[1], z2 * coef[0] * coef[1]
+    coef = torch.softmax(
+        gw_mod.precision_softmax_temp
+        * torch.stack([z1_precision, z2_precision]),
+        dim=0,
     )
+    norm = torch.sqrt(coef[0] * coef[1])
+    loss_output = contrastive_fn(z1 * norm, z2 * norm)
     loss_output_no_norm = contrastive_fn(z1, z2)

     losses[loss_name] = loss_output.loss
     metrics.update(
         {f"{loss_name}_{k}": v for k, v in loss_output.metrics.items()}
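Finally, a standalone sketch of the arithmetic behind the losses.py change (dummy tensors; names mirror the diff): the precisions are tempered by the same precision_softmax_temp before the softmax, and the square root makes the scaling act exactly once on any bilinear similarity. The pre-#72 code multiplied both embeddings by coef[0] * coef[1], which squared the factor in every dot product.

import torch

z1, z2 = torch.randn(4, 8), torch.randn(4, 8)              # dummy embeddings
z1_precision, z2_precision = torch.rand(4), torch.rand(4)  # dummy precisions
temp = 0.01                                                # precision_softmax_temp

coef = torch.softmax(temp * torch.stack([z1_precision, z2_precision]), dim=0)
norm = torch.sqrt(coef[0] * coef[1])   # per-sample; coef[0] + coef[1] == 1

# Scaling both sides by sqrt(coef[0] * coef[1]) scales each dot product
# z1[i] @ z2[i] by coef[0][i] * coef[1][i] exactly once, not squared.
scaled = ((z1 * norm.unsqueeze(-1)) * (z2 * norm.unsqueeze(-1))).sum(dim=-1)
direct = coef[0] * coef[1] * (z1 * z2).sum(dim=-1)
assert torch.allclose(scaled, direct, atol=1e-6)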
