Skip to content

Commit

Permalink
Fix wrong torch.nn.Module import
Browse files Browse the repository at this point in the history
  • Loading branch information
bdvllrs committed Mar 6, 2024
1 parent a2a0b20 commit 2c0861c
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 18 deletions.
16 changes: 8 additions & 8 deletions shimmer/modules/global_workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,8 +548,8 @@ def __init__(
gw_mod = GWModule(
domain_mods,
workspace_dim,
gw_encoders, # type: ignore
gw_decoders, # type: ignore
gw_encoders,
gw_decoders,
)
if contrastive_loss is None:
contrastive_loss = ContrastiveLoss(
Expand Down Expand Up @@ -626,8 +626,8 @@ def __init__(
gw_mod = VariationalGWModule(
domain_mods,
workspace_dim,
gw_encoders, # type: ignore
gw_decoders, # type: ignore
gw_encoders,
gw_decoders,
)

if use_var_contrastive_loss:
Expand Down Expand Up @@ -707,8 +707,8 @@ def __init__(
gw_mod = GWModuleFusion(
domain_mods,
workspace_dim,
gw_encoders, # type: ignore
gw_decoders, # type: ignore
gw_encoders,
gw_decoders,
)

if contrastive_loss is None:
Expand Down Expand Up @@ -772,8 +772,8 @@ def pretrained_global_workspace(
gw_mod = GWModule(
domain_mods,
workspace_dim,
gw_encoders, # type: ignore
gw_decoders, # type: ignore
gw_encoders,
gw_decoders,
)
loss_mod = GWLosses(
gw_mod,
Expand Down
20 changes: 10 additions & 10 deletions shimmer/modules/gw_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from collections.abc import Iterable, Mapping

import torch
from torch import Module, nn
from torch import nn

from shimmer.modules.domain import DomainModule
from shimmer.modules.vae import reparameterize
Expand Down Expand Up @@ -305,14 +305,14 @@ def cycle(self, x: LatentsDomainGroupT, through: str) -> LatentsDomainGroupDT:


class GWModule(GWModuleBase):
"""GW Module. Implements `GWModuleBase`."""
"""GW nn.Module. Implements `GWModuleBase`."""

def __init__(
self,
domain_modules: Mapping[str, DomainModule],
workspace_dim: int,
gw_encoders: Mapping[str, Module],
gw_decoders: Mapping[str, Module],
gw_encoders: Mapping[str, nn.Module],
gw_decoders: Mapping[str, nn.Module],
) -> None:
"""Initializes the GWModule.
Expand All @@ -328,10 +328,10 @@ def __init__(
"""
super().__init__(domain_modules, workspace_dim)

self.gw_encoders = nn.ModuleDict(gw_encoders) # type: ignore
self.gw_encoders = nn.ModuleDict(gw_encoders)
"""The module's encoders"""

self.gw_decoders = nn.ModuleDict(gw_decoders) # type: ignore
self.gw_decoders = nn.ModuleDict(gw_decoders)
"""The module's decoders"""

def fusion_mechanism(self, x: LatentsDomainGroupT) -> torch.Tensor:
Expand Down Expand Up @@ -414,8 +414,8 @@ def __init__(
self,
domain_modules: Mapping[str, DomainModule],
workspace_dim: int,
gw_encoders: Mapping[str, Module],
gw_decoders: Mapping[str, Module],
gw_encoders: Mapping[str, nn.Module],
gw_decoders: Mapping[str, nn.Module],
) -> None:
"""Initializes the VariationalGWModule.
Expand All @@ -431,10 +431,10 @@ def __init__(
"""
super().__init__(domain_modules, workspace_dim)

self.gw_encoders = nn.ModuleDict(gw_encoders) # type: ignore
self.gw_encoders = nn.ModuleDict(gw_encoders)
"""The module's encoders"""

self.gw_decoders = nn.ModuleDict(gw_decoders) # type: ignore
self.gw_decoders = nn.ModuleDict(gw_decoders)
"""The module's decoders"""

def fusion_mechanism(self, x: LatentsDomainGroupT) -> torch.Tensor:
Expand Down

0 comments on commit 2c0861c

Please sign in to comment.