diff --git a/shimmer/modules/global_workspace.py b/shimmer/modules/global_workspace.py index 94ae6ce8..24d18fab 100644 --- a/shimmer/modules/global_workspace.py +++ b/shimmer/modules/global_workspace.py @@ -695,8 +695,9 @@ def __init__( """ domain_mods = freeze_domain_modules(domain_mods) - gw_mod = GWModule(domain_mods, workspace_dim, gw_encoders, gw_decoders, - fusion_activation_fn) + gw_mod = GWModule( + domain_mods, workspace_dim, gw_encoders, gw_decoders, fusion_activation_fn + ) if contrastive_loss is None: contrastive_loss = ContrastiveLoss( torch.tensor([1 / 0.07]).log(), "mean", learn_logit_scale @@ -773,8 +774,9 @@ def __init__( function to fuse the domains. """ domain_mods = freeze_domain_modules(domain_mods) - gw_mod = GWModule(domain_mods, workspace_dim, gw_encoders, gw_decoders, - fusion_activation_fn) + gw_mod = GWModule( + domain_mods, workspace_dim, gw_encoders, gw_decoders, fusion_activation_fn + ) if contrastive_loss is None: contrastive_loss = ContrastiveLoss( @@ -844,8 +846,9 @@ def pretrained_global_workspace( `TypeError`: if loaded type is not `GlobalWorkspace`. """ domain_mods = freeze_domain_modules(domain_mods) - gw_mod = GWModule(domain_mods, workspace_dim, gw_encoders, gw_decoders, - fusion_activation_fn) + gw_mod = GWModule( + domain_mods, workspace_dim, gw_encoders, gw_decoders, fusion_activation_fn + ) selection_mod = SingleDomainSelection() loss_mod = GWLosses2Domains( gw_mod, selection_mod, domain_mods, loss_coefs, contrastive_fn