diff --git a/shimmer/modules/global_workspace.py b/shimmer/modules/global_workspace.py index d1b6bda5..93d71350 100644 --- a/shimmer/modules/global_workspace.py +++ b/shimmer/modules/global_workspace.py @@ -137,7 +137,7 @@ def workspace_dim(self) -> int: return self.gw_mod.workspace_dim def encode_and_fuse( - self, x: LatentsDomainGroupT, selection_scores: Mapping[str, torch.Tensor] + self, x: LatentsDomainGroupT, selection_module: SelectionBase ) -> torch.Tensor: """ Encode latent representations into the GW representation. @@ -154,7 +154,7 @@ def encode_and_fuse( Returns: `torch.Tensor`: the GW representations. """ - return self.gw_mod.encode_and_fuse(x, selection_scores) + return self.gw_mod.encode_and_fuse(x, selection_module) def encode(self, x: LatentsDomainGroupT) -> LatentsDomainGroupT: """ @@ -226,8 +226,9 @@ def batch_gw_states( if len(domains) > 1: continue domain_name = list(domains)[0] - scores = self.selection_mod(latents) - z = self.gw_mod.encode_and_fuse(latents, scores) + z = self.gw_mod.encode_and_fuse( + latents, selection_module=self.selection_mod + ) predictions[domain_name] = z return predictions diff --git a/shimmer/modules/gw_module.py b/shimmer/modules/gw_module.py index 3afe4cdd..bd78b4b7 100644 --- a/shimmer/modules/gw_module.py +++ b/shimmer/modules/gw_module.py @@ -5,6 +5,7 @@ from torch import nn from shimmer.modules.domain import DomainModule +from shimmer.modules.selection import SelectionBase from shimmer.modules.vae import reparameterize from shimmer.types import LatentsDomainGroupDT, LatentsDomainGroupT @@ -228,7 +229,7 @@ def encode(self, x: LatentsDomainGroupT) -> LatentsDomainGroupT: ... def encode_and_fuse( - self, x: LatentsDomainGroupT, selection_scores: Mapping[str, torch.Tensor] + self, x: LatentsDomainGroupT, selection_module: SelectionBase ) -> torch.Tensor: """ Encode the latent representation infos to the final GW representation. @@ -242,7 +243,9 @@ def encode_and_fuse( Returns: `torch.Tensor`: The merged representation. """ - return self.fuse(self.encode(x), selection_scores) + encodings = self.encode(x) + selection_scores = selection_module(x, encodings) + return self.fuse(encodings, selection_scores) @abstractmethod def decode( @@ -312,9 +315,11 @@ def fuse( def encode_and_fuse( self, x: LatentsDomainGroupT, - selection_scores: Mapping[str, torch.Tensor] | None = None, + selection_module: SelectionBase | None = None, ) -> torch.Tensor: - return self.fuse(self.encode(x), selection_scores) + encodings = self.encode(x) + selection_scores = selection_module(x, encodings) + return self.fuse(encodings, selection_scores) def encode(self, x: LatentsDomainGroupT) -> LatentsDomainGroupT: """ @@ -414,11 +419,11 @@ def encode(self, x: LatentsDomainGroupT) -> LatentsDomainGroupT: } def encode_and_fuse( - self, - x: LatentsDomainGroupT, - selection_scores: Mapping[str, torch.Tensor] | None = None, + self, x: LatentsDomainGroupT, selection_module=SelectionBase ) -> torch.Tensor: - return self.fuse(self.encode(x), selection_scores) + encodings = self.encode(x) + selection_scores = selection_module(x, encodings) + return self.fuse(encodings, selection_scores) def encoded_distribution( self, x: LatentsDomainGroupT diff --git a/shimmer/modules/losses.py b/shimmer/modules/losses.py index fa38ee9c..992c4124 100644 --- a/shimmer/modules/losses.py +++ b/shimmer/modules/losses.py @@ -77,10 +77,9 @@ def demi_cycle_loss( continue domain_name = next(iter(domains)) - selection_scores = selection_mod(latents) domain_mod = domain_mods[domain_name] x_recons = gw_mod.decode( - gw_mod.encode_and_fuse(latents, selection_scores), domains={domain_name} + gw_mod.encode_and_fuse(latents, selection_mod), domains={domain_name} )[domain_name] loss_output = domain_mod.compute_dcy_loss(x_recons, latents[domain_name]) losses[f"demi_cycle_{domain_name}"] = loss_output.loss @@ -125,18 +124,16 @@ def cycle_loss( continue domain_name_source = list(domains_source)[0] - selection_scores_source = selection_mod(latents_source) domain_mod = domain_mods[domain_name_source] - z = gw_mod.encode_and_fuse(latents_source, selection_scores_source) + z = gw_mod.encode_and_fuse(latents_source, selection_mod) for domain_name_target in domain_mods: if domain_name_target == domain_name_source: continue x_pred = gw_mod.decode(z, domains={domain_name_target}) - selection_scores_target = selection_mod(x_pred) x_recons = gw_mod.decode( - gw_mod.encode_and_fuse(x_pred, selection_scores_target), + gw_mod.encode_and_fuse(x_pred, selection_mod), domains={domain_name_source}, ) @@ -192,9 +189,7 @@ def translation_loss( if domain != domain_name_target } - selection_scores = selection_mod(domain_sources) - - z = gw_mod.encode_and_fuse(domain_sources, selection_scores) + z = gw_mod.encode_and_fuse(domain_sources, selection_mod) mod = domain_mods[domain_name_target] domain_source_names = "/".join(domain_sources.keys()) @@ -802,7 +797,7 @@ def broadcast_loss( selection_scores = self.selection_mod(scaled_latents) encoded_latents_for_subset = self.gw_mod.encode_and_fuse( - scaled_latents, selection_scores + scaled_latents, self.selection_mod ) encoded_latents_for_subset = torch.tanh(encoded_latents_for_subset) decoded_latents_for_subset = self.gw_mod.decode( diff --git a/shimmer/modules/selection.py b/shimmer/modules/selection.py index 2c10b2fe..a2b6198e 100644 --- a/shimmer/modules/selection.py +++ b/shimmer/modules/selection.py @@ -29,7 +29,9 @@ def update_gw_state(self, gw_state: torch.Tensor) -> None: pass @abstractmethod - def forward(self, domains: LatentsDomainGroupT) -> dict[str, torch.Tensor]: + def forward( + self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT + ) -> dict[str, torch.Tensor]: """ Forward pass of the selection method. @@ -49,8 +51,10 @@ def forward(self, domains: LatentsDomainGroupT) -> dict[str, torch.Tensor]: ... # This is just for proper auto-completion... - def __call__(self, domains: LatentsDomainGroupT) -> dict[str, torch.Tensor]: - return super().__call__(domains) + def __call__( + self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT + ) -> dict[str, torch.Tensor]: + return super().__call__(domains, encodings_pre_fusion) class SingleDomainSelection(SelectionBase): @@ -62,7 +66,9 @@ class SingleDomainSelection(SelectionBase): domain. """ - def forward(self, domains: LatentsDomainGroupT) -> dict[str, torch.Tensor]: + def forward( + self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT + ) -> dict[str, torch.Tensor]: """ Forward pass of the module. @@ -114,7 +120,9 @@ def update_gw_state(self, gw_state: torch.Tensor) -> None: """ self.gw_state = gw_state - def forward(self, domains: LatentsDomainGroupT) -> dict[str, torch.Tensor]: + def forward( + self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT + ) -> dict[str, torch.Tensor]: """ Compute keys and queries, match them with dot product and softmax. @@ -173,7 +181,9 @@ def __init__(self, binary_proportion: float, temperature: float): self.binary_proportion = binary_proportion self.temperature = temperature - def forward(self, domains: LatentsDomainGroupT) -> dict[str, torch.Tensor]: + def forward( + self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT + ) -> dict[str, torch.Tensor]: """ randomly draw binary and uniform-then-domain-wise-softmaxed samples according to self.binary_proportion. diff --git a/shimmer/modules/utils.py b/shimmer/modules/utils.py index 35c070f8..87616f79 100644 --- a/shimmer/modules/utils.py +++ b/shimmer/modules/utils.py @@ -30,10 +30,9 @@ def translation( `torch.Tensor`: the translated unimodal representation of the provided domain. """ - selection_scores = selection_mod(x) - return gw_module.decode( - gw_module.encode_and_fuse(x, selection_scores), domains={to} - )[to] + return gw_module.decode(gw_module.encode_and_fuse(x, selection_mod), domains={to})[ + to + ] def translation_with_uncertainty( diff --git a/tests/test_kq_onepass_attention.py b/tests/test_kq_onepass_attention.py index caf1db26..1cc1f10f 100644 --- a/tests/test_kq_onepass_attention.py +++ b/tests/test_kq_onepass_attention.py @@ -13,7 +13,8 @@ def test_single_domain(): attention.update_gw_state(gw_state) single_domain_input = {"v_latents": torch.rand(batch_size, domain_dim)} - attention_scores = attention(single_domain_input) + encodings_pre_fusion = {"v_latents": torch.rand(batch_size, domain_dim)} + attention_scores = attention(single_domain_input, encodings_pre_fusion) expected_scores = torch.ones(batch_size, 1) assert torch.allclose( @@ -33,7 +34,12 @@ def test_multiple_domains_sumis1(): "v_latents": torch.rand(batch_size, domain_dim), "attr": torch.rand(batch_size, domain_dim), } - attention_scores = attention(multiple_domain_input) + encodings_pre_fusion = { + "v_latents": torch.rand(batch_size, domain_dim), + "attr": torch.rand(batch_size, domain_dim), + } + + attention_scores = attention(multiple_domain_input, encodings_pre_fusion) scores_sum = sum( attention_scores[domain].squeeze() for domain in multiple_domain_input @@ -61,7 +67,11 @@ def test_attention_backward(): "attr": torch.rand(batch_size, domain_dim, requires_grad=True), } - attention_scores = attention(domains) + encodings_pre_fusion = { + "v_latents": torch.rand(batch_size, domain_dim, requires_grad=True), + "attr": torch.rand(batch_size, domain_dim, requires_grad=True), + } + attention_scores = attention(domains, encodings_pre_fusion) loss = sum(score.mean() for score in attention_scores.values()) assert isinstance(loss, torch.Tensor) loss.backward() diff --git a/tests/test_random_attention.py b/tests/test_random_attention.py index 03630c66..c887898e 100644 --- a/tests/test_random_attention.py +++ b/tests/test_random_attention.py @@ -15,7 +15,13 @@ def test_multiple_domains(): "v_latents": torch.rand(batch_size, domain_dim), "attr": torch.rand(batch_size, domain_dim), } - selection_scores = selection(multiple_domain_input) + + prefusion_encodings = { + "v_latents": torch.rand(batch_size, domain_dim), + "attr": torch.rand(batch_size, domain_dim), + } + + selection_scores = selection(multiple_domain_input, prefusion_encodings) # Ensure the sum of attention scores across domains equals 1 scores_sum = sum( @@ -42,7 +48,14 @@ def test_three_domains(): "attr": torch.rand(batch_size, domain_dim), "audio": torch.rand(batch_size, domain_dim), } - selection_scores = selection(three_domain_input) + + prefusion_encodings = { + "v_latents": torch.rand(batch_size, domain_dim), + "attr": torch.rand(batch_size, domain_dim), + "audio": torch.rand(batch_size, domain_dim), + } + + selection_scores = selection(three_domain_input, prefusion_encodings) # Ensure that the shape of the selection scores matches the input domains for domain in three_domain_input: @@ -83,7 +96,14 @@ def test_binary_scores_xor_check_for_multiple_proportions(): "attr": torch.rand(batch_size, domain_dim), "audio": torch.rand(batch_size, domain_dim), } - selection_scores = selection(domains_input) + + prefusion_encodings = { + "v_latents": torch.rand(batch_size, domain_dim), + "attr": torch.rand(batch_size, domain_dim), + "audio": torch.rand(batch_size, domain_dim), + } + + selection_scores = selection(domains_input, prefusion_encodings) scores_matrix = torch.cat( [selection_scores[domain] for domain in domains_input], dim=1 diff --git a/tests/test_single_domain_selection.py b/tests/test_single_domain_selection.py index 27e9ec08..90b44abe 100644 --- a/tests/test_single_domain_selection.py +++ b/tests/test_single_domain_selection.py @@ -8,8 +8,9 @@ def test_selection_1_domain(): bs = 32 domains = {"v": torch.randn(bs, 8)} + prefusion_encodings = {"v": torch.randn(bs, 8)} - selection: dict[str, torch.Tensor] = selection_mod(domains) + selection: dict[str, torch.Tensor] = selection_mod(domains, prefusion_encodings) assert len(selection) == len(domains) assert next(iter(selection.keys())) == "v" @@ -21,8 +22,9 @@ def test_selection_2_domains(): bs = 32 domains = {"v": torch.randn(bs, 8), "t": torch.randn(bs, 12)} + prefusion_encodings = {"v": torch.randn(bs, 8), "t": torch.randn(bs, 12)} - selection: dict[str, torch.Tensor] = selection_mod(domains) + selection: dict[str, torch.Tensor] = selection_mod(domains, prefusion_encodings) assert len(selection) == len(domains) assert ( @@ -39,8 +41,13 @@ def test_selection_3_domains(): "t": torch.randn(bs, 12), "attr": torch.randn(bs, 4), } + prefusion_encodings = { + "v": torch.randn(bs, 8), + "t": torch.randn(bs, 12), + "attr": torch.randn(bs, 4), + } - selection: dict[str, torch.Tensor] = selection_mod(domains) + selection: dict[str, torch.Tensor] = selection_mod(domains, prefusion_encodings) assert len(selection) == len(domains) assert ( diff --git a/tests/test_training.py b/tests/test_training.py index 9cf3d4b8..a96ac1a0 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -2,6 +2,7 @@ from utils import DummyData, DummyDataset, DummyDomainModule from shimmer import GlobalWorkspace, GWDecoder, GWEncoder +from shimmer.modules.selection import SingleDomainSelection def test_training(): @@ -82,11 +83,8 @@ def test_training(): assert isinstance(unimodal_latents["v"], torch.Tensor) assert unimodal_latents["v"].size() == (32, 128) - selection_scores = { - domain: torch.full((batch_size,), 1.0 / len(unimodal_latents)) - for domain in unimodal_latents - } - workspace_latent = gw.encode_and_fuse(unimodal_latents, selection_scores) + selection_module = SingleDomainSelection() + workspace_latent = gw.encode_and_fuse(unimodal_latents, selection_module) assert workspace_latent.size() == (32, 16) diff --git a/tests/utils.py b/tests/utils.py index b975a7c7..69a32f3b 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -4,7 +4,7 @@ import torch import torch.utils.data -from shimmer.modules import DomainModule +from shimmer.modules import DomainModule, SelectionBase PROJECT_DIR = Path(__file__).resolve().parent.parent @@ -36,3 +36,8 @@ def encode(self, x: DummyData) -> torch.Tensor: def decode(self, z: torch.Tensor) -> DummyData: return DummyData(vec=z) + + +class DummySelectionModule(SelectionBase): + def forward(self, x): + return {"v_latents": torch.ones(x["v_latents"].shape[0], 1)}