Rename gw_latent_dim to workspace_dim.
Add workspace_dim attribute in Global Workspace. Fixes #1
bdvllrs committed Jan 22, 2024
1 parent: 4790591 · commit: 55289de
Showing 4 changed files with 29 additions and 25 deletions.
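For downstream code the change is a one-word keyword rename. A minimal sketch of the new call, assuming a hypothetical DomainModule instance named my_domain and illustrative layer sizes (the signature itself is the one shown in the README.md and gw_module.py diffs below):

from shimmer import GWInterface

# Before this commit the keyword was gw_latent_dim; it is now workspace_dim.
my_iface = GWInterface(
    my_domain,             # hypothetical DomainModule instance
    workspace_dim=12,      # was: gw_latent_dim=12
    encoder_hidden_dim=32,
    encoder_n_layers=3,
    decoder_hidden_dim=32,
    decoder_n_layers=3,    # illustrative value
)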
README.md (2 changes: 1 addition & 1 deletion)
@@ -68,7 +68,7 @@ from shimmer import GWInterface
 my_domain = MyDomain()
 my_domain_gw_interface = GWInterface(
     my_domain,
-    gw_latent_dim=12, # latent dim of the global workspace
+    workspace_dim=12, # latent dim of the global workspace
     encoder_hidden_dim=32, # hidden dimension for the GW encoder
     encoder_n_layers=3, # n layers to use for the GW encoder
     decoder_hidden_dim=32, # hidden dimension for the GW decoder
shimmer/modules/global_workspace.py (18 changes: 11 additions & 7 deletions)
@@ -9,7 +9,7 @@
 
 from shimmer.modules.dict_buffer import DictBuffer
 from shimmer.modules.domain import DomainModule
-from shimmer.modules.gw_module import (DeterministicGWModule, GWInterface,
+from shimmer.modules.gw_module import (BaseGWInterface, DeterministicGWModule,
                                        GWModule, VariationalGWModule)
 from shimmer.modules.losses import (DeterministicGWLosses, GWLosses, LatentsT,
                                     VariationalGWLosses)
@@ -60,6 +60,10 @@ def __init__(
         if scheduler_args is not None:
             self.scheduler_args.update(scheduler_args)
 
+    @property
+    def workspace_dim(self):
+        return self.gw_mod.workspace_dim
+
     def encode(self, x: Mapping[str, torch.Tensor]) -> torch.Tensor:
         return self.gw_mod.encode(x)
 
@@ -264,14 +268,14 @@ class GlobalWorkspace(GlobalWorkspaceBase):
     def __init__(
         self,
         domain_mods: Mapping[str, DomainModule],
-        gw_interfaces: Mapping[str, GWInterface],
-        gw_latent_dim: int,
+        gw_interfaces: Mapping[str, BaseGWInterface],
+        workspace_dim: int,
         loss_coefs: dict[str, torch.Tensor],
         optim_lr: float = 1e-3,
         optim_weight_decay: float = 0.0,
         scheduler_args: SchedulerArgs | None = None,
     ) -> None:
-        gw_mod = DeterministicGWModule(gw_interfaces, gw_latent_dim)
+        gw_mod = DeterministicGWModule(gw_interfaces, workspace_dim)
         domain_mods = freeze_domain_modules(domain_mods)
         coef_buffers = DictBuffer(loss_coefs)
         loss_mod = DeterministicGWLosses(gw_mod, domain_mods, coef_buffers)
@@ -291,15 +295,15 @@ class VariationalGlobalWorkspace(GlobalWorkspaceBase):
     def __init__(
         self,
         domain_mods: Mapping[str, DomainModule],
-        gw_interfaces: Mapping[str, GWInterface],
-        gw_latent_dim: int,
+        gw_interfaces: Mapping[str, BaseGWInterface],
+        workspace_dim: int,
         loss_coefs: dict[str, torch.Tensor],
         var_contrastive_loss: bool = False,
         optim_lr: float = 1e-3,
         optim_weight_decay: float = 0.0,
         scheduler_args: SchedulerArgs | None = None,
     ) -> None:
-        gw_mod = VariationalGWModule(gw_interfaces, gw_latent_dim)
+        gw_mod = VariationalGWModule(gw_interfaces, workspace_dim)
         domain_mods = freeze_domain_modules(domain_mods)
         coef_buffers = DictBuffer(loss_coefs)
         loss_mod = VariationalGWLosses(
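The workspace_dim property added above makes the workspace dimensionality readable from the GlobalWorkspace object itself instead of only from its inner gw_mod. A small sketch, assuming domains and gw_interfaces dictionaries built as in tests/test_training.py further down:

gw = GlobalWorkspace(
    domains,
    gw_interfaces,
    workspace_dim=16,
    loss_coefs={},
)
assert gw.workspace_dim == 16  # forwarded from gw.gw_mod.workspace_dim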
shimmer/modules/gw_module.py (24 changes: 12 additions & 12 deletions)
@@ -82,11 +82,11 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
 
 class BaseGWInterface(nn.Module, ABC):
     def __init__(
-        self, domain_module: DomainModule, gw_latent_dim: int
+        self, domain_module: DomainModule, workspace_dim: int
     ) -> None:
         super().__init__()
         self.domain_module = domain_module
-        self.gw_latent_dim = gw_latent_dim
+        self.workspace_dim = workspace_dim
 
     @abstractmethod
     def encode(self, x: torch.Tensor) -> torch.Tensor:
@@ -99,14 +99,14 @@ def decode(self, z: torch.Tensor) -> torch.Tensor:
 
 class GWModule(nn.Module, ABC):
     def __init__(
-        self, gw_interfaces: Mapping[str, BaseGWInterface], gw_latent_dim: int
+        self, gw_interfaces: Mapping[str, BaseGWInterface], workspace_dim: int
     ) -> None:
         super().__init__()
         # casting for LSP autocompletion
         self.gw_interfaces = cast(
             dict[str, BaseGWInterface], nn.ModuleDict(gw_interfaces)
         )
-        self.latent_dim = gw_latent_dim
+        self.workspace_dim = workspace_dim
 
     def on_before_gw_encode_dcy(
         self, x: Mapping[str, torch.Tensor]
@@ -238,22 +238,22 @@ class GWInterface(BaseGWInterface):
     def __init__(
         self,
         domain_module: DomainModule,
-        gw_latent_dim: int,
+        workspace_dim: int,
         encoder_hidden_dim: int,
         encoder_n_layers: int,
         decoder_hidden_dim: int,
         decoder_n_layers: int,
     ) -> None:
-        super().__init__(domain_module, gw_latent_dim)
+        super().__init__(domain_module, workspace_dim)
 
         self.encoder = GWEncoder(
             domain_module.latent_dim,
             encoder_hidden_dim,
-            gw_latent_dim,
+            workspace_dim,
             encoder_n_layers,
         )
         self.decoder = GWDecoder(
-            gw_latent_dim,
+            workspace_dim,
             decoder_hidden_dim,
             domain_module.latent_dim,
             decoder_n_layers,
@@ -313,22 +313,22 @@ class VariationalGWInterface(BaseGWInterface):
     def __init__(
         self,
         domain_module: DomainModule,
-        gw_latent_dim: int,
+        workspace_dim: int,
         encoder_hidden_dim: int,
         encoder_n_layers: int,
         decoder_hidden_dim: int,
         decoder_n_layers: int,
     ) -> None:
-        super().__init__(domain_module, gw_latent_dim)
+        super().__init__(domain_module, workspace_dim)
 
         self.encoder = VariationalGWEncoder(
             domain_module.latent_dim,
             encoder_hidden_dim,
-            gw_latent_dim,
+            workspace_dim,
             encoder_n_layers,
         )
         self.decoder = GWDecoder(
-            gw_latent_dim,
+            workspace_dim,
             decoder_hidden_dim,
             domain_module.latent_dim,
             decoder_n_layers,
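The stored attributes are renamed as well: BaseGWInterface.gw_latent_dim and GWModule.latent_dim both become workspace_dim, so any code that read the old attributes needs the same one-word change. Reusing the hypothetical my_iface from the sketch near the top of this page:

# Old attribute reads removed by this commit:
#   my_iface.gw_latent_dim
#   some_gw_module.latent_dim
# New attribute name on both interfaces and GW modules:
assert my_iface.workspace_dim == 12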
tests/test_training.py (10 changes: 5 additions & 5 deletions)
@@ -16,28 +16,28 @@ def test_training():
         "a": DummyDomainModule(latent_dim=128),
     }
 
-    gw_latent_dim = 16
+    workspace_dim = 16
 
     gw_interfaces = {
         "v": GWInterface(
             domains["v"],
-            gw_latent_dim=gw_latent_dim,
+            workspace_dim=workspace_dim,
             encoder_hidden_dim=64,
             encoder_n_layers=1,
             decoder_hidden_dim=64,
             decoder_n_layers=1,
         ),
         "t": GWInterface(
             domains["t"],
-            gw_latent_dim=gw_latent_dim,
+            workspace_dim=workspace_dim,
             encoder_hidden_dim=64,
             encoder_n_layers=1,
             decoder_hidden_dim=64,
             decoder_n_layers=1,
         ),
         "a": GWInterface(
             domains["a"],
-            gw_latent_dim=gw_latent_dim,
+            workspace_dim=workspace_dim,
             encoder_hidden_dim=64,
             encoder_n_layers=1,
             decoder_hidden_dim=64,
@@ -48,7 +48,7 @@
     gw = GlobalWorkspace(
         domains,
         gw_interfaces,
-        gw_latent_dim=16,
+        workspace_dim=16,
         loss_coefs={},
     )
 
