From 1b7bf0a4f66e3fc6d394547588263bd204dbf564 Mon Sep 17 00:00:00 2001
From: RolandBERTINJOHANNET <78405585+RolandBERTINJOHANNET@users.noreply.github.com>
Date: Thu, 28 Mar 2024 17:28:39 +0100
Subject: [PATCH] Implement attention mechanism (#26)

* first draft attention mechanism
* first draft (not running)
* ruff fixes
* fixed errors in classification_loss
* got rid of attention-specific globalworkspace
* ruff fixes (again)
* got rid of task-specific gwlosses
* first draft for attention (far from finished)
* will need to resolve conflicts with upcoming selectionbase branch
* get rid of checkpoint files
* global workspace no longer has the attention mechanism
* rebase (for self.gw_states instead of forward arg)
* rebase (to get the self.gw_state instead of arg to forward)
* correct KQAttention module (todo docstring)
* added test for KQattention
* ruff checks
* trying ruff format again..
* forgot to ruff the tests
* few mypy fixes (still wont work)
* took out BinaryAttention -no longer needed
* fixed rebase's code duplications
* fixing rebase issues
* conform to base class with changed gw_state
* docstring + rewrote tests for new class
* random selection (todo tests & docstring)
* finished randomAttention minus docstring
* ruff format on attention test file
* added docstring for randomselection
* took out useless call in tests/test_random...
* requested changes
* adapted kq tests to new class name
* requested fixes

---------

Co-authored-by: Lara
---
 shimmer/modules/selection.py       | 132 +++++++++++++++++++++++++++++
 tests/test_kq_onepass_attention.py |  75 ++++++++++++++++
 tests/test_random_attention.py     |  94 ++++++++++++++++++++
 3 files changed, 301 insertions(+)
 create mode 100644 tests/test_kq_onepass_attention.py
 create mode 100644 tests/test_random_attention.py

diff --git a/shimmer/modules/selection.py b/shimmer/modules/selection.py
index c4b6c073..43e1fc3e 100644
--- a/shimmer/modules/selection.py
+++ b/shimmer/modules/selection.py
@@ -1,8 +1,10 @@
 from abc import ABC, abstractmethod
 
 import torch
+import torch.nn as nn
 
 from shimmer.types import LatentsDomainGroupT
+from shimmer.utils import group_batch_size, group_device
 
 
 class SelectionBase(torch.nn.Module, ABC):
@@ -45,3 +47,133 @@ def forward(self, domains: LatentsDomainGroupT) -> dict[str, torch.Tensor]:
             {"v": torch.Tensor([0.0, 0.4, 1.0]), "t": torch.Tensor([1.0, 0.6, 0.0])}
         """
         ...
+
+
+class KQFixedQSelection(SelectionBase):
+    """
+    Key-Query attention where the query is computed from a fixed GW state.
+    """
+
+    def __init__(self, domain_dim: int, head_size: int):
+        """
+        Args:
+            domain_dim (`int`): dimension of the input vectors (assumed to be the same for every domain for now)
+            head_size (`int`): dimension of the key and query vectors.
+        """
+        super().__init__()
+        self.head_size = head_size
+        self.query_layer = nn.Linear(domain_dim, head_size)
+        self.key_layers = nn.ModuleDict(
+            {
+                "v_latents": nn.Linear(domain_dim, head_size),
+                "attr": nn.Linear(domain_dim, head_size),
+            }
+        )
+        self.gw_state: torch.Tensor | None = None
+
+    def update_gw_state(self, gw_state: torch.Tensor) -> None:
+        """
+        Set the internal copy of the fixed gw state. This is meant to be called only once.
+
+        Args:
+            gw_state (`torch.Tensor`): the previous GW state.
+        """
+        self.gw_state = gw_state
+
+    def forward(self, domains: LatentsDomainGroupT) -> dict[str, torch.Tensor]:
+        """
+        Compute keys and queries, match them with dot product and softmax.
+
+        Args:
+            domains (`LatentsDomainGroupT`): Group of unimodal latent representations.
+
+        Returns:
+            `dict[str, torch.Tensor]`: for each domain in the group, the fusion
+            coefficient for each item in the batch.
+        """
+
+        if self.gw_state is None:
+            raise ValueError("GW state has not been initialized.")
+
+        keys = {
+            domain: self.key_layers[domain](encoding)
+            for domain, encoding in domains.items()
+        }
+
+        device = group_device(domains)
+        query = self.query_layer(self.gw_state.to(device))
+
+        dot_products = {
+            domain: torch.bmm(key.unsqueeze(1), query.unsqueeze(2)).reshape(-1)
+            for domain, key in keys.items()
+        }
+
+        dot_products_tensor = torch.stack(list(dot_products.values()), dim=1)
+
+        attention_scores = torch.softmax(dot_products_tensor, dim=1)
+
+        attention_dict = {
+            domain: attention_scores[:, i : i + 1] for i, domain in enumerate(keys)
+        }
+
+        return attention_dict
+
+
+class RandomSelection(SelectionBase):
+    """
+    Random attention, not learned: a proportion of the scaling factors are binary, and the rest are uniform scores softmaxed across modalities.
+    This class serves to train broadcast with robustness to linear scaling of pre-fusion representations.
+    """
+
+    def __init__(self, binary_proportion: float, temperature: float):
+        """
+        Args:
+            binary_proportion (`float`): proportion of binary scaling factors returned by forward(). Between 0 and 1.
+            temperature (`float`): temperature of the softmax applied to the uniform scaling factors.
+        """
+        super().__init__()
+        self.binary_proportion = binary_proportion
+        self.temperature = temperature
+
+    def forward(self, domains: LatentsDomainGroupT) -> dict[str, torch.Tensor]:
+        """
+        Randomly draw binary and uniform-then-softmaxed scaling factors according to self.binary_proportion.
+
+        Args:
+            domains (`LatentsDomainGroupT`): Group of unimodal latent representations. Only the number of domains and the batch size are used; the latent values themselves are ignored.
+
+        Returns:
+            `dict[str, torch.Tensor]`: for each domain in the group, the fusion
+            coefficient for each item in the batch.
+ """ + num_domains = len(domains) + batch_size = group_batch_size(domains) + + # have to add extra binaries when the division's not integer + total_binary_scores = int(batch_size * self.binary_proportion) + num_binary_per_domain, extra_binary_scores = divmod( + total_binary_scores, num_domains + ) + + # Calculate number of uniform scores taking into account extra binary scores + num_uniform = batch_size - total_binary_scores + + uniform_scores = torch.rand(num_uniform, num_domains) + softmax_scores = torch.softmax(uniform_scores / self.temperature, dim=1) + + # Generate binary scores, adjusting for any extra binary scores + binary_scores = [] + for i in range(num_domains): + binary_score = torch.zeros( + num_binary_per_domain + (1 if i < extra_binary_scores else 0), + num_domains, + ) + binary_score[:, i] = 1 + binary_scores.append(binary_score) + binary_scores_concat = torch.cat(binary_scores, dim=0) + + all_scores = torch.cat([softmax_scores, binary_scores_concat], dim=0) + attention_dict = { + domain: all_scores[:, i : i + 1] for i, domain in enumerate(domains) + } + return attention_dict diff --git a/tests/test_kq_onepass_attention.py b/tests/test_kq_onepass_attention.py new file mode 100644 index 00000000..8a35bdc8 --- /dev/null +++ b/tests/test_kq_onepass_attention.py @@ -0,0 +1,75 @@ +import torch + +from shimmer.modules.selection import KQFixedQSelection + + +def test_single_domain(): + domain_dim = 12 + head_size = 6 + batch_size = 2056 + + attention = KQFixedQSelection(domain_dim, head_size) + gw_state = torch.rand(batch_size, domain_dim) + attention.update_gw_state(gw_state) + + single_domain_input = {"v_latents": torch.rand(batch_size, domain_dim)} + attention_scores = attention(single_domain_input) + + expected_scores = torch.ones(batch_size, 1) + assert torch.allclose( + attention_scores["v_latents"], expected_scores + ), "Attention scores for single domain should be all 1s" + + +def test_multiple_domains_sumis1(): + domain_dim = 12 + head_size = 5 + batch_size = 2056 + attention = KQFixedQSelection(domain_dim, head_size) + gw_state = torch.rand(batch_size, domain_dim) + attention.update_gw_state(gw_state) + + multiple_domain_input = { + "v_latents": torch.rand(batch_size, domain_dim), + "attr": torch.rand(batch_size, domain_dim), + } + attention_scores = attention(multiple_domain_input) + + scores_sum = sum( + attention_scores[domain].squeeze() for domain in multiple_domain_input.keys() + ) + expected_sum = torch.ones(batch_size) + + assert torch.allclose( + scores_sum, expected_sum + ), "Sum of attention scores across domains should be 1" + + +def test_attention_backward(): + domain_dim = 12 + head_size = 6 + batch_size = 2056 + + attention = KQFixedQSelection(domain_dim, head_size) + gw_state = torch.rand(batch_size, domain_dim, requires_grad=True) + attention.update_gw_state(gw_state) + + domains = { + "v_latents": torch.rand(batch_size, domain_dim, requires_grad=True), + "attr": torch.rand(batch_size, domain_dim, requires_grad=True), + } + + attention_scores = attention(domains) + loss = sum(score.mean() for score in attention_scores.values()) + loss.backward() + + assert gw_state.grad is not None, "Gradients should be computed for gw_state" + for domain, tensor in domains.items(): + assert ( + tensor.grad is not None + ), f"Gradients should be computed for domain '{domain}' inputs" + + for name, param in attention.named_parameters(): + assert ( + param.grad is not None + ), f"Gradients should be computed for parameter '{name}'" diff --git 
diff --git a/tests/test_random_attention.py b/tests/test_random_attention.py
new file mode 100644
index 00000000..ea11cded
--- /dev/null
+++ b/tests/test_random_attention.py
@@ -0,0 +1,94 @@
+import numpy as np
+import torch
+
+from shimmer.modules.selection import RandomSelection
+
+
+def test_multiple_domains():
+    binary_proportion = 0.5
+    temperature = 1.0
+    domain_dim = 12
+    batch_size = 2056
+
+    selection = RandomSelection(binary_proportion, temperature)
+    multiple_domain_input = {
+        "v_latents": torch.rand(batch_size, domain_dim),
+        "attr": torch.rand(batch_size, domain_dim),
+    }
+    selection_scores = selection(multiple_domain_input)
+
+    # Ensure the sum of selection scores across domains equals 1
+    scores_sum = sum(
+        selection_scores[domain].squeeze() for domain in multiple_domain_input.keys()
+    )
+    expected_sum = torch.ones(batch_size)
+
+    assert torch.allclose(
+        scores_sum, expected_sum
+    ), "Sum of selection scores across domains should be 1"
+
+
+def test_three_domains():
+    binary_proportion = 0.5
+    temperature = 1.0
+    domain_dim = 12
+    batch_size = 2056
+
+    selection = RandomSelection(binary_proportion, temperature)
+    three_domain_input = {
+        "v_latents": torch.rand(batch_size, domain_dim),
+        "attr": torch.rand(batch_size, domain_dim),
+        "audio": torch.rand(batch_size, domain_dim),
+    }
+    selection_scores = selection(three_domain_input)
+
+    # Ensure that the shape of the selection scores matches the input domains
+    for domain in three_domain_input.keys():
+        assert selection_scores[domain].shape == (
+            batch_size,
+            1,
+        ), f"Scores shape mismatch for {domain}"
+
+    # The exact binary/uniform split is not asserted here; it is covered by
+    # test_binary_scores_xor_check_for_multiple_proportions below, which counts
+    # the one-hot rows for several random proportions
+
+    # Check if the sum of selection scores across domains equals 1
+    scores_sum = sum(
+        selection_scores[domain].squeeze() for domain in three_domain_input.keys()
+    )
+    expected_sum = torch.ones(batch_size)
+
+    assert torch.allclose(
+        scores_sum, expected_sum
+    ), "Sum of selection scores across three domains should be 1"
+
+
+def test_binary_scores_xor_check_for_multiple_proportions():
+    temperature = 1.0
+    domain_dim = 12
+    batch_size = 2056
+    num_tests = 10  # Number of random proportions to test
+
+    for _ in range(num_tests):
+        binary_proportion = np.random.rand()  # Random proportion between 0 and 1
+
+        selection = RandomSelection(binary_proportion, temperature)
+        domains_input = {
+            "v_latents": torch.rand(batch_size, domain_dim),
+            "attr": torch.rand(batch_size, domain_dim),
+            "audio": torch.rand(batch_size, domain_dim),
+        }
+        selection_scores = selection(domains_input)
+
+        scores_matrix = torch.cat(
+            [selection_scores[domain] for domain in domains_input.keys()], dim=1
+        )
+        binary_scores_mask = scores_matrix == 1
+        xor_binary_check = binary_scores_mask.sum(dim=1) == 1
+        num_binary_rows = xor_binary_check.sum().item()
+        expected_num_binary_rows = int(batch_size * binary_proportion)

+        assert (
+            num_binary_rows == expected_num_binary_rows
+        ), f"Incorrect number of binary score rows for proportion {binary_proportion:.2f}: expected {expected_num_binary_rows}, got {num_binary_rows}"
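
Usage sketch (illustrative, not part of the patch): the snippet below exercises the two selection modules added above, assuming the `shimmer.modules.selection` import path from the diff and the `v_latents`/`attr` domain names hard-coded in `KQFixedQSelection.key_layers`; shapes and hyperparameters mirror the tests.

import torch

from shimmer.modules.selection import KQFixedQSelection, RandomSelection

batch_size, domain_dim, head_size = 32, 12, 6
domains = {
    "v_latents": torch.rand(batch_size, domain_dim),
    "attr": torch.rand(batch_size, domain_dim),
}

# Key-query selection: register the fixed GW state once, then compute scores per batch.
kq_selection = KQFixedQSelection(domain_dim, head_size)
kq_selection.update_gw_state(torch.rand(batch_size, domain_dim))
kq_scores = kq_selection(domains)  # each value is (32, 1); the two scores per item sum to 1

# Random selection: scores depend only on the batch size and the number of domains.
random_selection = RandomSelection(binary_proportion=0.5, temperature=1.0)
random_scores = random_selection(domains)  # half the rows are one-hot, half softmaxed uniform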