diff --git a/tests/test_training.py b/tests/test_training.py index fa3648ba..2668fa27 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -1,7 +1,7 @@ import torch.utils.data from utils import DummyData, DummyDataset, DummyDomainModule -from shimmer import GlobalWorkspace, GWInterface +from shimmer import GlobalWorkspace, GWDecoder, GWEncoder def test_training(): @@ -18,36 +18,52 @@ def test_training(): workspace_dim = 16 - gw_interfaces = { - "v": GWInterface( - domains["v"], - workspace_dim=workspace_dim, - encoder_hidden_dim=64, - encoder_n_layers=1, - decoder_hidden_dim=64, - decoder_n_layers=1, + gw_encoders = { + "v": GWEncoder( + domains["v"].latent_dim, + hidden_dim=64, + out_dim=workspace_dim, + n_layers=1, ), - "t": GWInterface( - domains["t"], - workspace_dim=workspace_dim, - encoder_hidden_dim=64, - encoder_n_layers=1, - decoder_hidden_dim=64, - decoder_n_layers=1, + "t": GWEncoder( + domains["t"].latent_dim, + hidden_dim=64, + out_dim=workspace_dim, + n_layers=1, ), - "a": GWInterface( - domains["a"], - workspace_dim=workspace_dim, - encoder_hidden_dim=64, - encoder_n_layers=1, - decoder_hidden_dim=64, - decoder_n_layers=1, + "a": GWEncoder( + domains["a"].latent_dim, + hidden_dim=64, + out_dim=workspace_dim, + n_layers=1, + ), + } + + gw_decoders = { + "v": GWDecoder( + workspace_dim, + hidden_dim=64, + out_dim=domains["v"].latent_dim, + n_layers=1, + ), + "t": GWDecoder( + workspace_dim, + hidden_dim=64, + out_dim=domains["t"].latent_dim, + n_layers=1, + ), + "a": GWDecoder( + workspace_dim, + hidden_dim=64, + out_dim=domains["a"].latent_dim, + n_layers=1, ), } gw = GlobalWorkspace( domains, - gw_interfaces, + gw_encoders, + gw_decoders, workspace_dim=16, loss_coefs={}, )