diff --git a/shimmer/modules/selection.py b/shimmer/modules/selection.py index a0b3cfa5..060cccde 100644 --- a/shimmer/modules/selection.py +++ b/shimmer/modules/selection.py @@ -94,21 +94,19 @@ class KQFixedQSelection(SelectionBase): Key-Query attention with a fixed gw vector. """ - def __init__(self, domain_dim: int, head_size: int): + def __init__(self, domain_dim: int, head_size: int, domains: Iterable[str]): """ Args: domain_dim (`int`) : dimension of the input dims (assumed to be the same for now) head_size (`int`) : dimension of the key and query vectors. + domains (`Iterable[str]`) : list of input domains """ super().__init__() self.head_size = head_size self.query_layer = nn.Linear(domain_dim, head_size) self.key_layers = nn.ModuleDict( - { - "v_latents": nn.Linear(domain_dim, head_size), - "attr": nn.Linear(domain_dim, head_size), - } + {domain: nn.Linear(domain_dim, head_size) for domain in domains} ) self.gw_state: torch.Tensor | None = None diff --git a/tests/test_kq_onepass_attention.py b/tests/test_kq_onepass_attention.py index 1cc1f10f..c049f94e 100644 --- a/tests/test_kq_onepass_attention.py +++ b/tests/test_kq_onepass_attention.py @@ -7,8 +7,9 @@ def test_single_domain(): domain_dim = 12 head_size = 6 batch_size = 2056 + domains = ["v_latents"] - attention = KQFixedQSelection(domain_dim, head_size) + attention = KQFixedQSelection(domain_dim, head_size, domains) gw_state = torch.rand(batch_size, domain_dim) attention.update_gw_state(gw_state) @@ -26,7 +27,8 @@ def test_multiple_domains_sumis1(): domain_dim = 12 head_size = 5 batch_size = 2056 - attention = KQFixedQSelection(domain_dim, head_size) + domains = ["v_latents", "attr"] + attention = KQFixedQSelection(domain_dim, head_size, domains) gw_state = torch.rand(batch_size, domain_dim) attention.update_gw_state(gw_state) @@ -51,38 +53,3 @@ def test_multiple_domains_sumis1(): assert torch.allclose( scores_sum, expected_sum ), "Sum of attention scores across domains should be 1" - - -def test_attention_backward(): - domain_dim = 12 - head_size = 6 - batch_size = 2056 - - attention = KQFixedQSelection(domain_dim, head_size) - gw_state = torch.rand(batch_size, domain_dim, requires_grad=True) - attention.update_gw_state(gw_state) - - domains = { - "v_latents": torch.rand(batch_size, domain_dim, requires_grad=True), - "attr": torch.rand(batch_size, domain_dim, requires_grad=True), - } - - encodings_pre_fusion = { - "v_latents": torch.rand(batch_size, domain_dim, requires_grad=True), - "attr": torch.rand(batch_size, domain_dim, requires_grad=True), - } - attention_scores = attention(domains, encodings_pre_fusion) - loss = sum(score.mean() for score in attention_scores.values()) - assert isinstance(loss, torch.Tensor) - loss.backward() - - assert gw_state.grad is not None, "Gradients should be computed for gw_state" - for domain, tensor in domains.items(): - assert ( - tensor.grad is not None - ), f"Gradients should be computed for domain '{domain}' inputs" - - for name, param in attention.named_parameters(): - assert ( - param.grad is not None - ), f"Gradients should be computed for parameter '{name}'"