
Commit

add tests for freezing
bdvllrs committed Oct 3, 2024
1 parent 3de6cfb commit 40d2823
Showing 2 changed files with 81 additions and 0 deletions.
tests/test_freeze_domains.py (69 additions, 0 deletions)
@@ -0,0 +1,69 @@
from utils import DummyDomainModuleWithParams

from shimmer import GlobalWorkspace2Domains, GWDecoder, GWEncoder


def test_training():
    domains = {
        "v": DummyDomainModuleWithParams(latent_dim=128),
        "t": DummyDomainModuleWithParams(latent_dim=128),
        "a": DummyDomainModuleWithParams(latent_dim=128),
    }

    # Only "a" is unfrozen before the workspace is built; "v" and "t"
    # keep the default (frozen) state.
    domains["a"].unfreeze()

    workspace_dim = 16

    gw_encoders = {
        "v": GWEncoder(
            domains["v"].latent_dim,
            hidden_dim=64,
            out_dim=workspace_dim,
            n_layers=1,
        ),
        "t": GWEncoder(
            domains["t"].latent_dim,
            hidden_dim=64,
            out_dim=workspace_dim,
            n_layers=1,
        ),
        "a": GWEncoder(
            domains["a"].latent_dim,
            hidden_dim=64,
            out_dim=workspace_dim,
            n_layers=1,
        ),
    }

    gw_decoders = {
        "v": GWDecoder(
            workspace_dim,
            hidden_dim=64,
            out_dim=domains["v"].latent_dim,
            n_layers=1,
        ),
        "t": GWDecoder(
            workspace_dim,
            hidden_dim=64,
            out_dim=domains["t"].latent_dim,
            n_layers=1,
        ),
        "a": GWDecoder(
            workspace_dim,
            hidden_dim=64,
            out_dim=domains["a"].latent_dim,
            n_layers=1,
        ),
    }

    gw = GlobalWorkspace2Domains(
        domains,
        gw_encoders,
        gw_decoders,
        workspace_dim=workspace_dim,
        loss_coefs={},
    )
    # Frozen modules expose no trainable parameters; unfrozen ones do.
    assert gw.domain_mods["v"].is_frozen
    assert not gw.domain_mods["a"].is_frozen
    assert not len([p for p in gw.domain_mods["v"].parameters() if p.requires_grad])
    assert len([p for p in gw.domain_mods["a"].parameters() if p.requires_grad])
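For context on the two requires_grad assertions: in standard PyTorch, freezing a module amounts to setting requires_grad=False on its parameters, which hides them from the usual trainable-parameter filters. A minimal plain-PyTorch sketch of that mechanism (illustrative only, not part of this commit):

import torch

net = torch.nn.Linear(1, 1)
for p in net.parameters():
    p.requires_grad = False  # what freezing amounts to at the tensor level

# Mirrors the gw.domain_mods["v"] check above: a frozen module exposes
# no trainable parameters.
assert not [p for p in net.parameters() if p.requires_grad]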
tests/utils.py (12 additions, 0 deletions)
@@ -36,3 +36,15 @@ def encode(self, x: DummyData) -> torch.Tensor:

    def decode(self, z: torch.Tensor) -> DummyData:
        return DummyData(vec=z)


class DummyDomainModuleWithParams(DomainModule):
    def __init__(self, latent_dim: int) -> None:
        super().__init__(latent_dim)
        # A real trainable layer, so freezing/unfreezing has parameters
        # whose requires_grad flags the test can inspect.
        self.net = torch.nn.Linear(1, 1)

    def encode(self, x: DummyData) -> torch.Tensor:
        return x.vec

    def decode(self, z: torch.Tensor) -> DummyData:
        return DummyData(vec=z)
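This helper exists because the freezing test needs a module with at least one real parameter to inspect; the torch.nn.Linear(1, 1) layer contributes a weight and a bias. A quick sketch of what the test's parameter checks see (illustrative, assuming it runs from the tests directory as pytest would):

from utils import DummyDomainModuleWithParams

module = DummyDomainModuleWithParams(latent_dim=128)
# Linear(1, 1) registers a (1, 1) weight and a (1,) bias, so
# parameters() is non-empty and the requires_grad assertions in
# test_freeze_domains.py have something to check.
assert len(list(module.parameters())) >= 2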
