From 2c0861c50de2275b769d21b97a20750cb8d0fb08 Mon Sep 17 00:00:00 2001
From: bdvllrs
Date: Wed, 6 Mar 2024 13:54:48 +0000
Subject: [PATCH] Fix wrong torch.nn.Module import

---
 shimmer/modules/global_workspace.py | 16 ++++++++--------
 shimmer/modules/gw_module.py        | 20 ++++++++++----------
 2 files changed, 18 insertions(+), 18 deletions(-)

diff --git a/shimmer/modules/global_workspace.py b/shimmer/modules/global_workspace.py
index a8452dbe..138d8c1f 100644
--- a/shimmer/modules/global_workspace.py
+++ b/shimmer/modules/global_workspace.py
@@ -548,8 +548,8 @@ def __init__(
         gw_mod = GWModule(
             domain_mods,
             workspace_dim,
-            gw_encoders,  # type: ignore
-            gw_decoders,  # type: ignore
+            gw_encoders,
+            gw_decoders,
         )
         if contrastive_loss is None:
             contrastive_loss = ContrastiveLoss(
@@ -626,8 +626,8 @@ def __init__(
         gw_mod = VariationalGWModule(
             domain_mods,
             workspace_dim,
-            gw_encoders,  # type: ignore
-            gw_decoders,  # type: ignore
+            gw_encoders,
+            gw_decoders,
         )
 
         if use_var_contrastive_loss:
@@ -707,8 +707,8 @@ def __init__(
         gw_mod = GWModuleFusion(
             domain_mods,
             workspace_dim,
-            gw_encoders,  # type: ignore
-            gw_decoders,  # type: ignore
+            gw_encoders,
+            gw_decoders,
         )
 
         if contrastive_loss is None:
@@ -772,8 +772,8 @@ def pretrained_global_workspace(
     gw_mod = GWModule(
         domain_mods,
         workspace_dim,
-        gw_encoders,  # type: ignore
-        gw_decoders,  # type: ignore
+        gw_encoders,
+        gw_decoders,
     )
     loss_mod = GWLosses(
         gw_mod,
diff --git a/shimmer/modules/gw_module.py b/shimmer/modules/gw_module.py
index 0b653976..4aebf7b6 100644
--- a/shimmer/modules/gw_module.py
+++ b/shimmer/modules/gw_module.py
@@ -2,7 +2,7 @@
 from collections.abc import Iterable, Mapping
 
 import torch
-from torch import Module, nn
+from torch import nn
 
 from shimmer.modules.domain import DomainModule
 from shimmer.modules.vae import reparameterize
@@ -305,14 +305,14 @@ def cycle(self, x: LatentsDomainGroupT, through: str) -> LatentsDomainGroupDT:
 
 
 class GWModule(GWModuleBase):
-    """GW Module. Implements `GWModuleBase`."""
+    """GW nn.Module. Implements `GWModuleBase`."""
 
     def __init__(
         self,
         domain_modules: Mapping[str, DomainModule],
         workspace_dim: int,
-        gw_encoders: Mapping[str, Module],
-        gw_decoders: Mapping[str, Module],
+        gw_encoders: Mapping[str, nn.Module],
+        gw_decoders: Mapping[str, nn.Module],
     ) -> None:
         """Initializes the GWModule.
 
@@ -328,10 +328,10 @@ def __init__(
         """
         super().__init__(domain_modules, workspace_dim)
 
-        self.gw_encoders = nn.ModuleDict(gw_encoders)  # type: ignore
+        self.gw_encoders = nn.ModuleDict(gw_encoders)
         """The module's encoders"""
 
-        self.gw_decoders = nn.ModuleDict(gw_decoders)  # type: ignore
+        self.gw_decoders = nn.ModuleDict(gw_decoders)
         """The module's decoders"""
 
     def fusion_mechanism(self, x: LatentsDomainGroupT) -> torch.Tensor:
@@ -414,8 +414,8 @@ def __init__(
         self,
         domain_modules: Mapping[str, DomainModule],
         workspace_dim: int,
-        gw_encoders: Mapping[str, Module],
-        gw_decoders: Mapping[str, Module],
+        gw_encoders: Mapping[str, nn.Module],
+        gw_decoders: Mapping[str, nn.Module],
     ) -> None:
         """Initializes the VariationalGWModule.
 
@@ -431,10 +431,10 @@ def __init__(
         """
         super().__init__(domain_modules, workspace_dim)
 
-        self.gw_encoders = nn.ModuleDict(gw_encoders)  # type: ignore
+        self.gw_encoders = nn.ModuleDict(gw_encoders)
         """The module's encoders"""
 
-        self.gw_decoders = nn.ModuleDict(gw_decoders)  # type: ignore
+        self.gw_decoders = nn.ModuleDict(gw_decoders)
         """The module's decoders"""
 
     def fusion_mechanism(self, x: LatentsDomainGroupT) -> torch.Tensor:
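
Note for reviewers: `Module` is defined in `torch.nn` and is not re-exported by the top-level `torch` package, so the old `from torch import Module` line fails with an `ImportError` as soon as `gw_module.py` is imported. Once the annotations read `Mapping[str, nn.Module]`, the arguments also match the mapping type that `nn.ModuleDict` accepts, which is presumably why the `# type: ignore` comments can be dropped. A minimal sketch of both points, using a hypothetical `ToyGWModule` and an illustrative domain name standing in for shimmer's real classes:

```python
from collections.abc import Mapping

import torch
from torch import nn  # correct: Module lives in torch.nn, not top-level torch

# `from torch import Module` raises ImportError, because the top-level
# `torch` package does not re-export nn.Module.


class ToyGWModule(nn.Module):
    """Hypothetical stand-in for shimmer's GWModule, reduced to the lines
    this patch touches."""

    def __init__(self, gw_encoders: Mapping[str, nn.Module]) -> None:
        super().__init__()
        # With the annotation fixed to Mapping[str, nn.Module], this call
        # matches nn.ModuleDict's expected argument type, so no
        # `# type: ignore` is needed.
        self.gw_encoders = nn.ModuleDict(gw_encoders)

    def forward(self, x: torch.Tensor, domain: str) -> torch.Tensor:
        # Route the latent vector through the encoder registered for `domain`.
        return self.gw_encoders[domain](x)


if __name__ == "__main__":
    toy = ToyGWModule({"v_latents": nn.Linear(4, 2)})  # domain name is illustrative
    print(toy(torch.randn(1, 4), "v_latents").shape)  # torch.Size([1, 2])
```

The same reasoning covers the `global_workspace.py` hunks: they only drop `# type: ignore` comments that were papering over the bad `Module` annotations in `gw_module.py`.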