From 6ff207e168ba785fbd5dc7161f6cba53509f023d Mon Sep 17 00:00:00 2001
From: bdvllrs
Date: Fri, 23 Feb 2024 09:35:04 +0000
Subject: [PATCH] Add option to learn logit scale

When `learn_logit_scale` is set, the logit scale is registered as a
trainable `torch.nn.Parameter` (so it is updated by the optimizer)
instead of a fixed buffer. Defaults to False, preserving the previous
behaviour. Note: the parameter must wrap the constructor argument
`logit_scale`, not `self.logit_scale`, which does not exist yet at
that point in `__init__`.

---
 shimmer/modules/contrastive_loss.py | 18 ++++++++++++------
 shimmer/modules/global_workspace.py | 10 +++++++---
 2 files changed, 19 insertions(+), 9 deletions(-)

diff --git a/shimmer/modules/contrastive_loss.py b/shimmer/modules/contrastive_loss.py
index 1c10cfa3..f0a897c9 100644
--- a/shimmer/modules/contrastive_loss.py
+++ b/shimmer/modules/contrastive_loss.py
@@ -61,16 +61,19 @@ def contrastive_loss_with_uncertainty(
 
 
 
 class ContrastiveLoss(torch.nn.Module):
-    logit_scale: torch.Tensor
-
     def __init__(
         self,
         logit_scale: torch.Tensor,
         reduction: Literal["mean", "sum", "none"] = "mean",
+        learn_logit_scale: bool = False,
     ) -> None:
         super().__init__()
-        self.register_buffer("logit_scale", logit_scale)
+        if learn_logit_scale:
+            self.logit_scale = torch.nn.Parameter(logit_scale)
+        else:
+            self.register_buffer("logit_scale", logit_scale)
+        self.learn_logit_scale = learn_logit_scale
         self.reduction: Literal["mean", "sum", "none"] = reduction
 
     def forward(self, x: torch.Tensor, y: torch.Tensor) -> LossOutput:
@@ -81,16 +84,19 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> LossOutput:
 
 
 
 class ContrastiveLossWithUncertainty(torch.nn.Module):
-    logit_scale: torch.Tensor
-
     def __init__(
         self,
         logit_scale: torch.Tensor,
         reduction: Literal["mean", "sum", "none"] = "mean",
+        learn_logit_scale: bool = False,
     ) -> None:
         super().__init__()
-        self.register_buffer("logit_scale", logit_scale)
+        if learn_logit_scale:
+            self.logit_scale = torch.nn.Parameter(logit_scale)
+        else:
+            self.register_buffer("logit_scale", logit_scale)
+        self.learn_logit_scale = learn_logit_scale
         self.reduction: Literal["mean", "sum", "none"] = reduction
 
     def forward(
diff --git a/shimmer/modules/global_workspace.py b/shimmer/modules/global_workspace.py
index c75bee05..fbbea561 100644
--- a/shimmer/modules/global_workspace.py
+++ b/shimmer/modules/global_workspace.py
@@ -275,6 +275,7 @@ def __init__(
         optim_lr: float = 1e-3,
         optim_weight_decay: float = 0.0,
         scheduler_args: SchedulerArgs | None = None,
+        learn_logit_scale: bool = False,
     ) -> None:
         gw_mod = GWModule(gw_interfaces, workspace_dim)
         domain_mods = freeze_domain_modules(domain_mods)
@@ -283,7 +284,7 @@
             gw_mod,
             domain_mods,
             coef_buffers,
-            ContrastiveLoss(torch.tensor([1 / 0.07]).log(), "mean"),
+            ContrastiveLoss(torch.tensor([1 / 0.07]).log(), "mean", learn_logit_scale),
         )
 
         super().__init__(
@@ -308,6 +309,7 @@ def __init__(
         optim_lr: float = 1e-3,
         optim_weight_decay: float = 0.0,
         scheduler_args: SchedulerArgs | None = None,
+        learn_logit_scale: bool = False,
     ) -> None:
         gw_mod = VariationalGWModule(gw_interfaces, workspace_dim)
         domain_mods = freeze_domain_modules(domain_mods)
@@ -319,7 +321,7 @@
                 domain_mods,
                 coef_buffers,
                 var_contrastive_fn=ContrastiveLossWithUncertainty(
-                    torch.tensor([1]).log(), "mean"
+                    torch.tensor([1]).log(), "mean", learn_logit_scale
                 ),
             )
         else:
@@ -327,7 +329,9 @@
                 gw_mod,
                 domain_mods,
                 coef_buffers,
-                contrastive_fn=ContrastiveLoss(torch.tensor([1]).log(), "mean"),
+                contrastive_fn=ContrastiveLoss(
+                    torch.tensor([1]).log(), "mean", learn_logit_scale
+                ),
             )
 
         super().__init__(