Skip to content

Commit

Permalink
Add prefusion encodings to selection base (#41)
Browse files Browse the repository at this point in the history
* added prefusion encodings to selection modules

* fixed errors
  • Loading branch information
larascipio authored Apr 2, 2024
1 parent 6742864 commit 0913906
Show file tree
Hide file tree
Showing 10 changed files with 97 additions and 47 deletions.
9 changes: 5 additions & 4 deletions shimmer/modules/global_workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def workspace_dim(self) -> int:
return self.gw_mod.workspace_dim

def encode_and_fuse(
self, x: LatentsDomainGroupT, selection_scores: Mapping[str, torch.Tensor]
self, x: LatentsDomainGroupT, selection_module: SelectionBase
) -> torch.Tensor:
"""
Encode latent representations into the GW representation.
Expand All @@ -154,7 +154,7 @@ def encode_and_fuse(
Returns:
`torch.Tensor`: the GW representations.
"""
return self.gw_mod.encode_and_fuse(x, selection_scores)
return self.gw_mod.encode_and_fuse(x, selection_module)

def encode(self, x: LatentsDomainGroupT) -> LatentsDomainGroupT:
"""
Expand Down Expand Up @@ -226,8 +226,9 @@ def batch_gw_states(
if len(domains) > 1:
continue
domain_name = list(domains)[0]
scores = self.selection_mod(latents)
z = self.gw_mod.encode_and_fuse(latents, scores)
z = self.gw_mod.encode_and_fuse(
latents, selection_module=self.selection_mod
)
predictions[domain_name] = z
return predictions

Expand Down
21 changes: 13 additions & 8 deletions shimmer/modules/gw_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from torch import nn

from shimmer.modules.domain import DomainModule
from shimmer.modules.selection import SelectionBase
from shimmer.modules.vae import reparameterize
from shimmer.types import LatentsDomainGroupDT, LatentsDomainGroupT

Expand Down Expand Up @@ -228,7 +229,7 @@ def encode(self, x: LatentsDomainGroupT) -> LatentsDomainGroupT:
...

def encode_and_fuse(
self, x: LatentsDomainGroupT, selection_scores: Mapping[str, torch.Tensor]
self, x: LatentsDomainGroupT, selection_module: SelectionBase
) -> torch.Tensor:
"""
Encode the latent representation infos to the final GW representation.
Expand All @@ -242,7 +243,9 @@ def encode_and_fuse(
Returns:
`torch.Tensor`: The merged representation.
"""
return self.fuse(self.encode(x), selection_scores)
encodings = self.encode(x)
selection_scores = selection_module(x, encodings)
return self.fuse(encodings, selection_scores)

@abstractmethod
def decode(
Expand Down Expand Up @@ -312,9 +315,11 @@ def fuse(
def encode_and_fuse(
self,
x: LatentsDomainGroupT,
selection_scores: Mapping[str, torch.Tensor] | None = None,
selection_module: SelectionBase | None = None,
) -> torch.Tensor:
return self.fuse(self.encode(x), selection_scores)
encodings = self.encode(x)
selection_scores = selection_module(x, encodings)
return self.fuse(encodings, selection_scores)

def encode(self, x: LatentsDomainGroupT) -> LatentsDomainGroupT:
"""
Expand Down Expand Up @@ -414,11 +419,11 @@ def encode(self, x: LatentsDomainGroupT) -> LatentsDomainGroupT:
}

def encode_and_fuse(
self,
x: LatentsDomainGroupT,
selection_scores: Mapping[str, torch.Tensor] | None = None,
self, x: LatentsDomainGroupT, selection_module=SelectionBase
) -> torch.Tensor:
return self.fuse(self.encode(x), selection_scores)
encodings = self.encode(x)
selection_scores = selection_module(x, encodings)
return self.fuse(encodings, selection_scores)

def encoded_distribution(
self, x: LatentsDomainGroupT
Expand Down
15 changes: 5 additions & 10 deletions shimmer/modules/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,9 @@ def demi_cycle_loss(
continue
domain_name = next(iter(domains))

selection_scores = selection_mod(latents)
domain_mod = domain_mods[domain_name]
x_recons = gw_mod.decode(
gw_mod.encode_and_fuse(latents, selection_scores), domains={domain_name}
gw_mod.encode_and_fuse(latents, selection_mod), domains={domain_name}
)[domain_name]
loss_output = domain_mod.compute_dcy_loss(x_recons, latents[domain_name])
losses[f"demi_cycle_{domain_name}"] = loss_output.loss
Expand Down Expand Up @@ -125,18 +124,16 @@ def cycle_loss(
continue
domain_name_source = list(domains_source)[0]

selection_scores_source = selection_mod(latents_source)
domain_mod = domain_mods[domain_name_source]
z = gw_mod.encode_and_fuse(latents_source, selection_scores_source)
z = gw_mod.encode_and_fuse(latents_source, selection_mod)
for domain_name_target in domain_mods:
if domain_name_target == domain_name_source:
continue

x_pred = gw_mod.decode(z, domains={domain_name_target})

selection_scores_target = selection_mod(x_pred)
x_recons = gw_mod.decode(
gw_mod.encode_and_fuse(x_pred, selection_scores_target),
gw_mod.encode_and_fuse(x_pred, selection_mod),
domains={domain_name_source},
)

Expand Down Expand Up @@ -192,9 +189,7 @@ def translation_loss(
if domain != domain_name_target
}

selection_scores = selection_mod(domain_sources)

z = gw_mod.encode_and_fuse(domain_sources, selection_scores)
z = gw_mod.encode_and_fuse(domain_sources, selection_mod)
mod = domain_mods[domain_name_target]

domain_source_names = "/".join(domain_sources.keys())
Expand Down Expand Up @@ -802,7 +797,7 @@ def broadcast_loss(
selection_scores = self.selection_mod(scaled_latents)

encoded_latents_for_subset = self.gw_mod.encode_and_fuse(
scaled_latents, selection_scores
scaled_latents, self.selection_mod
)
encoded_latents_for_subset = torch.tanh(encoded_latents_for_subset)
decoded_latents_for_subset = self.gw_mod.decode(
Expand Down
22 changes: 16 additions & 6 deletions shimmer/modules/selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ def update_gw_state(self, gw_state: torch.Tensor) -> None:
pass

@abstractmethod
def forward(self, domains: LatentsDomainGroupT) -> dict[str, torch.Tensor]:
def forward(
self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT
) -> dict[str, torch.Tensor]:
"""
Forward pass of the selection method.
Expand All @@ -49,8 +51,10 @@ def forward(self, domains: LatentsDomainGroupT) -> dict[str, torch.Tensor]:
...

# This is just for proper auto-completion...
def __call__(self, domains: LatentsDomainGroupT) -> dict[str, torch.Tensor]:
return super().__call__(domains)
def __call__(
self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT
) -> dict[str, torch.Tensor]:
return super().__call__(domains, encodings_pre_fusion)


class SingleDomainSelection(SelectionBase):
Expand All @@ -62,7 +66,9 @@ class SingleDomainSelection(SelectionBase):
domain.
"""

def forward(self, domains: LatentsDomainGroupT) -> dict[str, torch.Tensor]:
def forward(
self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT
) -> dict[str, torch.Tensor]:
"""
Forward pass of the module.
Expand Down Expand Up @@ -114,7 +120,9 @@ def update_gw_state(self, gw_state: torch.Tensor) -> None:
"""
self.gw_state = gw_state

def forward(self, domains: LatentsDomainGroupT) -> dict[str, torch.Tensor]:
def forward(
self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT
) -> dict[str, torch.Tensor]:
"""
Compute keys and queries, match them with dot product and softmax.
Expand Down Expand Up @@ -173,7 +181,9 @@ def __init__(self, binary_proportion: float, temperature: float):
self.binary_proportion = binary_proportion
self.temperature = temperature

def forward(self, domains: LatentsDomainGroupT) -> dict[str, torch.Tensor]:
def forward(
self, domains: LatentsDomainGroupT, encodings_pre_fusion: LatentsDomainGroupT
) -> dict[str, torch.Tensor]:
"""
randomly draw binary and uniform-then-domain-wise-softmaxed samples according
to self.binary_proportion.
Expand Down
7 changes: 3 additions & 4 deletions shimmer/modules/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,9 @@ def translation(
`torch.Tensor`: the translated unimodal representation
of the provided domain.
"""
selection_scores = selection_mod(x)
return gw_module.decode(
gw_module.encode_and_fuse(x, selection_scores), domains={to}
)[to]
return gw_module.decode(gw_module.encode_and_fuse(x, selection_mod), domains={to})[
to
]


def translation_with_uncertainty(
Expand Down
16 changes: 13 additions & 3 deletions tests/test_kq_onepass_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ def test_single_domain():
attention.update_gw_state(gw_state)

single_domain_input = {"v_latents": torch.rand(batch_size, domain_dim)}
attention_scores = attention(single_domain_input)
encodings_pre_fusion = {"v_latents": torch.rand(batch_size, domain_dim)}
attention_scores = attention(single_domain_input, encodings_pre_fusion)

expected_scores = torch.ones(batch_size, 1)
assert torch.allclose(
Expand All @@ -33,7 +34,12 @@ def test_multiple_domains_sumis1():
"v_latents": torch.rand(batch_size, domain_dim),
"attr": torch.rand(batch_size, domain_dim),
}
attention_scores = attention(multiple_domain_input)
encodings_pre_fusion = {
"v_latents": torch.rand(batch_size, domain_dim),
"attr": torch.rand(batch_size, domain_dim),
}

attention_scores = attention(multiple_domain_input, encodings_pre_fusion)

scores_sum = sum(
attention_scores[domain].squeeze() for domain in multiple_domain_input
Expand Down Expand Up @@ -61,7 +67,11 @@ def test_attention_backward():
"attr": torch.rand(batch_size, domain_dim, requires_grad=True),
}

attention_scores = attention(domains)
encodings_pre_fusion = {
"v_latents": torch.rand(batch_size, domain_dim, requires_grad=True),
"attr": torch.rand(batch_size, domain_dim, requires_grad=True),
}
attention_scores = attention(domains, encodings_pre_fusion)
loss = sum(score.mean() for score in attention_scores.values())
assert isinstance(loss, torch.Tensor)
loss.backward()
Expand Down
26 changes: 23 additions & 3 deletions tests/test_random_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,13 @@ def test_multiple_domains():
"v_latents": torch.rand(batch_size, domain_dim),
"attr": torch.rand(batch_size, domain_dim),
}
selection_scores = selection(multiple_domain_input)

prefusion_encodings = {
"v_latents": torch.rand(batch_size, domain_dim),
"attr": torch.rand(batch_size, domain_dim),
}

selection_scores = selection(multiple_domain_input, prefusion_encodings)

# Ensure the sum of attention scores across domains equals 1
scores_sum = sum(
Expand All @@ -42,7 +48,14 @@ def test_three_domains():
"attr": torch.rand(batch_size, domain_dim),
"audio": torch.rand(batch_size, domain_dim),
}
selection_scores = selection(three_domain_input)

prefusion_encodings = {
"v_latents": torch.rand(batch_size, domain_dim),
"attr": torch.rand(batch_size, domain_dim),
"audio": torch.rand(batch_size, domain_dim),
}

selection_scores = selection(three_domain_input, prefusion_encodings)

# Ensure that the shape of the selection scores matches the input domains
for domain in three_domain_input:
Expand Down Expand Up @@ -83,7 +96,14 @@ def test_binary_scores_xor_check_for_multiple_proportions():
"attr": torch.rand(batch_size, domain_dim),
"audio": torch.rand(batch_size, domain_dim),
}
selection_scores = selection(domains_input)

prefusion_encodings = {
"v_latents": torch.rand(batch_size, domain_dim),
"attr": torch.rand(batch_size, domain_dim),
"audio": torch.rand(batch_size, domain_dim),
}

selection_scores = selection(domains_input, prefusion_encodings)

scores_matrix = torch.cat(
[selection_scores[domain] for domain in domains_input], dim=1
Expand Down
13 changes: 10 additions & 3 deletions tests/test_single_domain_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@ def test_selection_1_domain():

bs = 32
domains = {"v": torch.randn(bs, 8)}
prefusion_encodings = {"v": torch.randn(bs, 8)}

selection: dict[str, torch.Tensor] = selection_mod(domains)
selection: dict[str, torch.Tensor] = selection_mod(domains, prefusion_encodings)

assert len(selection) == len(domains)
assert next(iter(selection.keys())) == "v"
Expand All @@ -21,8 +22,9 @@ def test_selection_2_domains():

bs = 32
domains = {"v": torch.randn(bs, 8), "t": torch.randn(bs, 12)}
prefusion_encodings = {"v": torch.randn(bs, 8), "t": torch.randn(bs, 12)}

selection: dict[str, torch.Tensor] = selection_mod(domains)
selection: dict[str, torch.Tensor] = selection_mod(domains, prefusion_encodings)

assert len(selection) == len(domains)
assert (
Expand All @@ -39,8 +41,13 @@ def test_selection_3_domains():
"t": torch.randn(bs, 12),
"attr": torch.randn(bs, 4),
}
prefusion_encodings = {
"v": torch.randn(bs, 8),
"t": torch.randn(bs, 12),
"attr": torch.randn(bs, 4),
}

selection: dict[str, torch.Tensor] = selection_mod(domains)
selection: dict[str, torch.Tensor] = selection_mod(domains, prefusion_encodings)

assert len(selection) == len(domains)
assert (
Expand Down
8 changes: 3 additions & 5 deletions tests/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from utils import DummyData, DummyDataset, DummyDomainModule

from shimmer import GlobalWorkspace, GWDecoder, GWEncoder
from shimmer.modules.selection import SingleDomainSelection


def test_training():
Expand Down Expand Up @@ -82,11 +83,8 @@ def test_training():
assert isinstance(unimodal_latents["v"], torch.Tensor)
assert unimodal_latents["v"].size() == (32, 128)

selection_scores = {
domain: torch.full((batch_size,), 1.0 / len(unimodal_latents))
for domain in unimodal_latents
}
workspace_latent = gw.encode_and_fuse(unimodal_latents, selection_scores)
selection_module = SingleDomainSelection()
workspace_latent = gw.encode_and_fuse(unimodal_latents, selection_module)

assert workspace_latent.size() == (32, 16)

Expand Down
7 changes: 6 additions & 1 deletion tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch
import torch.utils.data

from shimmer.modules import DomainModule
from shimmer.modules import DomainModule, SelectionBase

PROJECT_DIR = Path(__file__).resolve().parent.parent

Expand Down Expand Up @@ -36,3 +36,8 @@ def encode(self, x: DummyData) -> torch.Tensor:

def decode(self, z: torch.Tensor) -> DummyData:
return DummyData(vec=z)


class DummySelectionModule(SelectionBase):
def forward(self, x):
return {"v_latents": torch.ones(x["v_latents"].shape[0], 1)}

0 comments on commit 0913906

Please sign in to comment.