
Commit be33403

Fix ruff formatting error
HugoChateauLaurent committed Oct 4, 2024
1 parent 68fb04f commit be33403
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions shimmer/modules/global_workspace.py
@@ -695,8 +695,9 @@ def __init__(
         """
         domain_mods = freeze_domain_modules(domain_mods)
 
-        gw_mod = GWModule(domain_mods, workspace_dim, gw_encoders, gw_decoders,
-                          fusion_activation_fn)
+        gw_mod = GWModule(
+            domain_mods, workspace_dim, gw_encoders, gw_decoders, fusion_activation_fn
+        )
         if contrastive_loss is None:
             contrastive_loss = ContrastiveLoss(
                 torch.tensor([1 / 0.07]).log(), "mean", learn_logit_scale
@@ -773,8 +774,9 @@ def __init__(
             function to fuse the domains.
         """
         domain_mods = freeze_domain_modules(domain_mods)
-        gw_mod = GWModule(domain_mods, workspace_dim, gw_encoders, gw_decoders,
-                          fusion_activation_fn)
+        gw_mod = GWModule(
+            domain_mods, workspace_dim, gw_encoders, gw_decoders, fusion_activation_fn
+        )
 
         if contrastive_loss is None:
             contrastive_loss = ContrastiveLoss(
@@ -844,8 +846,9 @@ def pretrained_global_workspace(
         `TypeError`: if loaded type is not `GlobalWorkspace`.
     """
     domain_mods = freeze_domain_modules(domain_mods)
-    gw_mod = GWModule(domain_mods, workspace_dim, gw_encoders, gw_decoders,
-                      fusion_activation_fn)
+    gw_mod = GWModule(
+        domain_mods, workspace_dim, gw_encoders, gw_decoders, fusion_activation_fn
+    )
     selection_mod = SingleDomainSelection()
     loss_mod = GWLosses2Domains(
         gw_mod, selection_mod, domain_mods, loss_coefs, contrastive_fn
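
Note: the "ruff formatting error" in the commit message refers to code that does not match the output of ruff's formatter, for example as reported by a formatting check in CI. Below is a minimal sketch for reproducing such a check locally, assuming ruff is installed in the environment; the exact command the project's CI runs is not shown on this page.

    # Illustrative only: run ruff's formatter in check mode on the changed file.
    # Assumes ruff is available (e.g. installed via `pip install ruff`).
    import subprocess

    result = subprocess.run(
        ["ruff", "format", "--check", "shimmer/modules/global_workspace.py"],
        capture_output=True,
        text=True,
    )
    # A non-zero return code means the file would be reformatted by `ruff format`,
    # which is the kind of failure this commit fixes.
    print(result.returncode)
    print(result.stdout)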
