Skip to content

Commit

Permalink
Update test to remove gw_interfaces
Browse files Browse the repository at this point in the history
  • Loading branch information
bdvllrs committed Mar 6, 2024
1 parent 57edb4e commit a8015ac
Showing 1 changed file with 40 additions and 24 deletions.
64 changes: 40 additions & 24 deletions tests/test_training.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch.utils.data
from utils import DummyData, DummyDataset, DummyDomainModule

from shimmer import GlobalWorkspace, GWInterface
from shimmer import GlobalWorkspace, GWDecoder, GWEncoder


def test_training():
Expand All @@ -18,36 +18,52 @@ def test_training():

workspace_dim = 16

gw_interfaces = {
"v": GWInterface(
domains["v"],
workspace_dim=workspace_dim,
encoder_hidden_dim=64,
encoder_n_layers=1,
decoder_hidden_dim=64,
decoder_n_layers=1,
gw_encoders = {
"v": GWEncoder(
domains["v"].latent_dim,
hidden_dim=64,
out_dim=workspace_dim,
n_layers=1,
),
"t": GWInterface(
domains["t"],
workspace_dim=workspace_dim,
encoder_hidden_dim=64,
encoder_n_layers=1,
decoder_hidden_dim=64,
decoder_n_layers=1,
"t": GWEncoder(
domains["t"].latent_dim,
hidden_dim=64,
out_dim=workspace_dim,
n_layers=1,
),
"a": GWInterface(
domains["a"],
workspace_dim=workspace_dim,
encoder_hidden_dim=64,
encoder_n_layers=1,
decoder_hidden_dim=64,
decoder_n_layers=1,
"a": GWEncoder(
domains["a"].latent_dim,
hidden_dim=64,
out_dim=workspace_dim,
n_layers=1,
),
}

gw_decoders = {
"v": GWDecoder(
workspace_dim,
hidden_dim=64,
out_dim=domains["v"].latent_dim,
n_layers=1,
),
"t": GWDecoder(
workspace_dim,
hidden_dim=64,
out_dim=domains["t"].latent_dim,
n_layers=1,
),
"a": GWDecoder(
workspace_dim,
hidden_dim=64,
out_dim=domains["a"].latent_dim,
n_layers=1,
),
}

gw = GlobalWorkspace(
domains,
gw_interfaces,
gw_encoders,
gw_decoders,
workspace_dim=16,
loss_coefs={},
)
Expand Down

0 comments on commit a8015ac

Please sign in to comment.