From 2ddda1fdbd2702948ffb9733c0a55f78fbe9d0db Mon Sep 17 00:00:00 2001
From: bdvllrs
Date: Fri, 23 Feb 2024 11:12:37 +0000
Subject: [PATCH] Can pass contrastive loss as param

---
 shimmer/modules/global_workspace.py | 30 ++++++++++++++++++++---------
 1 file changed, 21 insertions(+), 9 deletions(-)

diff --git a/shimmer/modules/global_workspace.py b/shimmer/modules/global_workspace.py
index fbbea561..f71fe1a5 100644
--- a/shimmer/modules/global_workspace.py
+++ b/shimmer/modules/global_workspace.py
@@ -12,6 +12,7 @@
     ContrastiveLoss,
     ContrastiveLossType,
     ContrastiveLossWithUncertainty,
+    VarContrastiveLossType,
 )
 from shimmer.modules.dict_buffer import DictBuffer
 from shimmer.modules.domain import DomainModule
@@ -276,15 +277,20 @@ def __init__(
         optim_weight_decay: float = 0.0,
         scheduler_args: SchedulerArgs | None = None,
         learn_logit_scale: bool = False,
+        contrastive_loss: ContrastiveLossType | None = None,
     ) -> None:
         gw_mod = GWModule(gw_interfaces, workspace_dim)
         domain_mods = freeze_domain_modules(domain_mods)
         coef_buffers = DictBuffer(loss_coefs)
+        if contrastive_loss is None:
+            contrastive_loss = ContrastiveLoss(
+                torch.tensor([1 / 0.07]).log(), "mean", learn_logit_scale
+            )
         loss_mod = GWLosses(
             gw_mod,
             domain_mods,
             coef_buffers,
-            ContrastiveLoss(torch.tensor([1 / 0.07]).log(), "mean", learn_logit_scale),
+            contrastive_loss,
         )
 
         super().__init__(
@@ -305,33 +311,39 @@ def __init__(
         gw_interfaces: Mapping[str, GWInterfaceBase],
         workspace_dim: int,
         loss_coefs: Mapping[str, torch.Tensor],
-        var_contrastive_loss: bool = False,
+        use_var_contrastive_loss: bool = False,
         optim_lr: float = 1e-3,
         optim_weight_decay: float = 0.0,
         scheduler_args: SchedulerArgs | None = None,
         learn_logit_scale: bool = False,
+        contrastive_loss: ContrastiveLossType | None = None,
+        var_contrastive_loss: VarContrastiveLossType | None = None,
     ) -> None:
         gw_mod = VariationalGWModule(gw_interfaces, workspace_dim)
         domain_mods = freeze_domain_modules(domain_mods)
         coef_buffers = DictBuffer(loss_coefs)
-        if var_contrastive_loss:
+        if use_var_contrastive_loss:
+            if var_contrastive_loss is None:
+                var_contrastive_loss = ContrastiveLossWithUncertainty(
+                    torch.tensor([1]).log(), "mean", learn_logit_scale
+                )
             loss_mod = VariationalGWLosses(
                 gw_mod,
                 domain_mods,
                 coef_buffers,
-                var_contrastive_fn=ContrastiveLossWithUncertainty(
-                    torch.tensor([1]).log(), "mean", learn_logit_scale
-                ),
+                var_contrastive_fn=var_contrastive_loss,
             )
         else:
+            if contrastive_loss is None:
+                contrastive_loss = ContrastiveLoss(
+                    torch.tensor([1]).log(), "mean", learn_logit_scale
+                )
             loss_mod = VariationalGWLosses(
                 gw_mod,
                 domain_mods,
                 coef_buffers,
-                contrastive_fn=ContrastiveLoss(
-                    torch.tensor([1]).log(), "mean", learn_logit_scale
-                ),
+                contrastive_fn=contrastive_loss,
             )
 
         super().__init__(
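
Usage sketch (editor's note, not part of the patch): the snippet below illustrates the new contrastive_loss keyword. It assumes the patched constructor belongs to the GlobalWorkspace class in shimmer.modules.global_workspace, that ContrastiveLoss takes (logit_scale, reduction, learn_logit_scale) positionally as in the diff above, and that its import path is shimmer.modules.contrastive_loss; domain_mods, gw_interfaces, and loss_coefs are placeholders for objects built elsewhere.

    import torch

    # Import paths assumed; adjust to wherever these are exported in your shimmer version.
    from shimmer.modules.contrastive_loss import ContrastiveLoss
    from shimmer.modules.global_workspace import GlobalWorkspace

    # Hypothetical custom temperature: logit scale log(1 / 0.1) instead of the
    # built-in default log(1 / 0.07), with the scale kept learnable.
    custom_contrastive = ContrastiveLoss(
        torch.tensor([1 / 0.1]).log(), "mean", True
    )

    gw = GlobalWorkspace(
        domain_mods=domain_mods,        # Mapping[str, DomainModule]; frozen by the constructor
        gw_interfaces=gw_interfaces,    # Mapping[str, GWInterfaceBase]
        workspace_dim=12,               # arbitrary example value
        loss_coefs=loss_coefs,          # Mapping[str, torch.Tensor]
        contrastive_loss=custom_contrastive,  # new keyword introduced by this patch
    )

Leaving contrastive_loss (or var_contrastive_loss in the variational class) as None keeps the previous behaviour, since the defaults constructed inside the patched __init__ match the losses that were hard-coded before.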