
Add option to learn logit scale
bdvllrs committed Feb 23, 2024
1 parent 40017ac commit 6ff207e
Showing 2 changed files with 19 additions and 9 deletions.
shimmer/modules/contrastive_loss.py (12 additions, 6 deletions)
@@ -61,16 +61,19 @@ def contrastive_loss_with_uncertainty(


class ContrastiveLoss(torch.nn.Module):
    logit_scale: torch.Tensor

    def __init__(
        self,
        logit_scale: torch.Tensor,
        reduction: Literal["mean", "sum", "none"] = "mean",
        learn_logit_scale: bool = False,
    ) -> None:
        super().__init__()

        self.register_buffer("logit_scale", logit_scale)
        if learn_logit_scale:
            self.logit_scale = torch.nn.Parameter(self.logit_scale)
        else:
            self.register_buffer("logit_scale", logit_scale)
        self.learn_logit_scale = learn_logit_scale
        self.reduction: Literal["mean", "sum", "none"] = reduction

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> LossOutput:
@@ -81,16 +84,19 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> LossOutput:


class ContrastiveLossWithUncertainty(torch.nn.Module):
    logit_scale: torch.Tensor

    def __init__(
        self,
        logit_scale: torch.Tensor,
        reduction: Literal["mean", "sum", "none"] = "mean",
        learn_logit_scale: bool = False,
    ) -> None:
        super().__init__()

        self.register_buffer("logit_scale", logit_scale)
        if learn_logit_scale:
            self.logit_scale = torch.nn.Parameter(self.logit_scale)
        else:
            self.register_buffer("logit_scale", logit_scale)
        self.learn_logit_scale = learn_logit_scale
        self.reduction: Literal["mean", "sum", "none"] = reduction

    def forward(
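The substance of the change in both classes is whether `logit_scale` ends up as a buffer (fixed, but still stored in the `state_dict` and moved with the module) or as a `torch.nn.Parameter` (trainable). A minimal, self-contained sketch of that pattern (a toy module for illustration, not shimmer's actual class):

```python
import torch


class ToyScaledLoss(torch.nn.Module):
    """Toy module showing buffer vs. learnable parameter (illustration only)."""

    def __init__(self, logit_scale: torch.Tensor, learn_logit_scale: bool = False) -> None:
        super().__init__()
        if learn_logit_scale:
            # Trainable: appears in .parameters() and receives gradients.
            self.logit_scale = torch.nn.Parameter(logit_scale.clone())
        else:
            # Fixed: saved/loaded with the state_dict, but never optimized.
            self.register_buffer("logit_scale", logit_scale)


fixed = ToyScaledLoss(torch.tensor([1 / 0.07]).log())
learned = ToyScaledLoss(torch.tensor([1 / 0.07]).log(), learn_logit_scale=True)

print(sum(p.numel() for p in fixed.parameters()))    # 0 (nothing to optimize)
print(sum(p.numel() for p in learned.parameters()))  # 1 (the learnable logit scale)
print("logit_scale" in dict(fixed.named_buffers()))  # True
```

With `learn_logit_scale=True`, any optimizer built over the module's `parameters()` will therefore update the scale alongside the rest of the model.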
shimmer/modules/global_workspace.py (7 additions, 3 deletions)
@@ -275,6 +275,7 @@ def __init__(
        optim_lr: float = 1e-3,
        optim_weight_decay: float = 0.0,
        scheduler_args: SchedulerArgs | None = None,
        learn_logit_scale: bool = False,
    ) -> None:
        gw_mod = GWModule(gw_interfaces, workspace_dim)
        domain_mods = freeze_domain_modules(domain_mods)
@@ -283,7 +284,7 @@ def __init__(
            gw_mod,
            domain_mods,
            coef_buffers,
            ContrastiveLoss(torch.tensor([1 / 0.07]).log(), "mean"),
            ContrastiveLoss(torch.tensor([1 / 0.07]).log(), "mean", learn_logit_scale),
        )

        super().__init__(
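A note on the initial value used here: `torch.tensor([1 / 0.07]).log()` is the CLIP-style initialization, i.e. the log of an inverse temperature of 0.07. Since the constructor receives the log, the loss presumably exponentiates it before use (as CLIP does), so the effective scale starts at about 14.29 whether or not it is subsequently learned. A quick sanity check in plain PyTorch:

```python
import torch

init = torch.tensor([1 / 0.07]).log()
print(init)        # tensor([2.6593])
print(init.exp())  # tensor([14.2857]), i.e. 1 / 0.07, the inverted CLIP temperature
```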
@@ -308,6 +309,7 @@ def __init__(
        optim_lr: float = 1e-3,
        optim_weight_decay: float = 0.0,
        scheduler_args: SchedulerArgs | None = None,
        learn_logit_scale: bool = False,
    ) -> None:
        gw_mod = VariationalGWModule(gw_interfaces, workspace_dim)
        domain_mods = freeze_domain_modules(domain_mods)
@@ -319,15 +321,17 @@ def __init__(
                domain_mods,
                coef_buffers,
                var_contrastive_fn=ContrastiveLossWithUncertainty(
                    torch.tensor([1]).log(), "mean"
                    torch.tensor([1]).log(), "mean", learn_logit_scale
                ),
            )
        else:
            loss_mod = VariationalGWLosses(
                gw_mod,
                domain_mods,
                coef_buffers,
                contrastive_fn=ContrastiveLoss(torch.tensor([1]).log(), "mean"),
                contrastive_fn=ContrastiveLoss(
                    torch.tensor([1]).log(), "mean", learn_logit_scale
                ),
            )

        super().__init__(
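Downstream, the flag is simply forwarded from the workspace constructors in global_workspace.py into the contrastive loss modules. A hedged usage sketch, assuming only the import path and the `(logit_scale, reduction, learn_logit_scale)` signature shown in this diff (requires shimmer to be installed):

```python
import torch
from shimmer.modules.contrastive_loss import ContrastiveLoss

# Signature as shown in the diff above; learn_logit_scale=True should make the
# scale a trainable parameter rather than a fixed buffer.
loss_fn = ContrastiveLoss(torch.tensor([1 / 0.07]).log(), "mean", learn_logit_scale=True)

print([name for name, _ in loss_fn.named_parameters()])  # expected: ['logit_scale']
print([name for name, _ in loss_fn.named_buffers()])     # expected: []
```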

