Commit 2ddda1f
Can pass contrastive loss as param
bdvllrs committed Feb 23, 2024
1 parent 143c8ea commit 2ddda1f
Showing 1 changed file with 21 additions and 9 deletions.
shimmer/modules/global_workspace.py
```diff
@@ -12,6 +12,7 @@
     ContrastiveLoss,
     ContrastiveLossType,
     ContrastiveLossWithUncertainty,
+    VarContrastiveLossType,
 )
 from shimmer.modules.dict_buffer import DictBuffer
 from shimmer.modules.domain import DomainModule
```
```diff
@@ -276,15 +277,20 @@ def __init__(
         optim_weight_decay: float = 0.0,
         scheduler_args: SchedulerArgs | None = None,
         learn_logit_scale: bool = False,
+        contrastive_loss: ContrastiveLossType | None = None,
     ) -> None:
         gw_mod = GWModule(gw_interfaces, workspace_dim)
         domain_mods = freeze_domain_modules(domain_mods)
         coef_buffers = DictBuffer(loss_coefs)
+        if contrastive_loss is None:
+            contrastive_loss = ContrastiveLoss(
+                torch.tensor([1 / 0.07]).log(), "mean", learn_logit_scale
+            )
         loss_mod = GWLosses(
             gw_mod,
             domain_mods,
             coef_buffers,
-            ContrastiveLoss(torch.tensor([1 / 0.07]).log(), "mean", learn_logit_scale),
+            contrastive_loss,
         )

         super().__init__(
```
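With this change the hard-coded `ContrastiveLoss(torch.tensor([1 / 0.07]).log(), "mean", learn_logit_scale)` is only built when no `contrastive_loss` is supplied, so a caller can now pass their own loss object. A minimal sketch of how that might look, assuming this constructor belongs to the non-variational `GlobalWorkspace` class and that its leading positional arguments are `domain_mods`, `gw_interfaces`, `workspace_dim`, `loss_coefs` as implied by the diff; the helper name `build_gw` and the 1 / 0.1 logit scale are purely illustrative:

```python
import torch

# Assumed import targets: ContrastiveLoss is re-exported by the changed module,
# and GlobalWorkspace is assumed to be the class whose __init__ is patched above.
from shimmer.modules.global_workspace import ContrastiveLoss, GlobalWorkspace


def build_gw(domain_mods, gw_interfaces, workspace_dim, loss_coefs):
    """Illustrative helper: build a workspace with a non-default contrastive loss."""
    # Same constructor arguments as the default in the diff, but with a 1 / 0.1
    # initial logit scale instead of 1 / 0.07 (values are purely illustrative).
    custom_loss = ContrastiveLoss(torch.tensor([1 / 0.1]).log(), "mean", True)
    return GlobalWorkspace(
        domain_mods,
        gw_interfaces,
        workspace_dim,
        loss_coefs,
        contrastive_loss=custom_loss,  # new keyword introduced by this commit
    )
```

Left as `None`, the new parameter falls back to the previous hard-coded default, so existing call sites keep the old behaviour.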
```diff
@@ -305,33 +311,39 @@ def __init__(
         gw_interfaces: Mapping[str, GWInterfaceBase],
         workspace_dim: int,
         loss_coefs: Mapping[str, torch.Tensor],
-        var_contrastive_loss: bool = False,
+        use_var_contrastive_loss: bool = False,
         optim_lr: float = 1e-3,
         optim_weight_decay: float = 0.0,
         scheduler_args: SchedulerArgs | None = None,
         learn_logit_scale: bool = False,
+        contrastive_loss: ContrastiveLossType | None = None,
+        var_contrastive_loss: VarContrastiveLossType | None = None,
     ) -> None:
         gw_mod = VariationalGWModule(gw_interfaces, workspace_dim)
         domain_mods = freeze_domain_modules(domain_mods)
         coef_buffers = DictBuffer(loss_coefs)

-        if var_contrastive_loss:
+        if use_var_contrastive_loss:
+            if var_contrastive_loss is None:
+                var_contrastive_loss = ContrastiveLossWithUncertainty(
+                    torch.tensor([1]).log(), "mean", learn_logit_scale
+                )
             loss_mod = VariationalGWLosses(
                 gw_mod,
                 domain_mods,
                 coef_buffers,
-                var_contrastive_fn=ContrastiveLossWithUncertainty(
-                    torch.tensor([1]).log(), "mean", learn_logit_scale
-                ),
+                var_contrastive_fn=var_contrastive_loss,
             )
         else:
+            if contrastive_loss is None:
+                contrastive_loss = ContrastiveLoss(
+                    torch.tensor([1]).log(), "mean", learn_logit_scale
+                )
             loss_mod = VariationalGWLosses(
                 gw_mod,
                 domain_mods,
                 coef_buffers,
-                contrastive_fn=ContrastiveLoss(
-                    torch.tensor([1]).log(), "mean", learn_logit_scale
-                ),
+                contrastive_fn=contrastive_loss,
             )

         super().__init__(
```
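In the variational constructor the boolean switch is renamed to `use_var_contrastive_loss`, and both branches now accept an externally constructed loss (`var_contrastive_loss` or `contrastive_loss`), falling back to the previous defaults when these are `None`. A rough usage sketch under the same caveats as above; `VariationalGlobalWorkspace` and the argument order are assumptions, since the diff only shows the `__init__` body:

```python
import torch

# Assumed import targets: ContrastiveLossWithUncertainty is re-exported by the
# changed module; VariationalGlobalWorkspace is an assumed class name.
from shimmer.modules.global_workspace import (
    ContrastiveLossWithUncertainty,
    VariationalGlobalWorkspace,
)


def build_var_gw(domain_mods, gw_interfaces, workspace_dim, loss_coefs):
    """Illustrative helper: variational workspace with an uncertainty-aware loss."""
    var_loss = ContrastiveLossWithUncertainty(
        torch.tensor([1.0]).log(),  # same initial logit scale as the default branch
        "mean",
        True,  # learn_logit_scale
    )
    return VariationalGlobalWorkspace(
        domain_mods,
        gw_interfaces,
        workspace_dim,
        loss_coefs,
        use_var_contrastive_loss=True,  # renamed flag selecting the uncertainty-aware branch
        var_contrastive_loss=var_loss,  # new keyword; if None, the default is built instead
    )
```

Passing `use_var_contrastive_loss=False` together with a plain `contrastive_loss` takes the `else` branch instead.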
