Implement attention mechanism (#26)
* first draft attention mechanism

* first draft (not running)

* ruff fixes

* fixed errors in classification_loss

* got rid of attention-specific globalworkspace

* ruff fixes (again)

* got rid of task-specific gwlosses

* first draft for attention (far from finished)

* will need to resolve conflicts with upcoming selectionbase branch

* get rid of checkpoint files

* global workspace no longer has the attention mechanism

* rebase (for self.gw_states instead of forward arg)

* rebase (to get the self.gw_state instead of arg to forward)

* correct KQAttention module (todo docstring)

* added test for KQattention

* ruff checks

* trying ruff format again..

* forgot to ruff the tests

* few mypy fixes (still won't work)

* took out BinaryAttention -no longer needed

* fixed rebase's code duplications

* fixing rebase issues

* conform to base class with changed gw_state

* docstring + rewrote tests for new class

* random selection (todo tests & docstring)

* finished randomAttention minus docstring

* ruff format on attention test file

* added docstring for randomselection

* took out useless call in tests/test_random...

* requested changes

* adapted kq tests to new class name

* requested fixes

---------

Co-authored-by: Lara <[email protected]>
RolandBERTINJOHANNET and larascipio authored Mar 28, 2024
1 parent d4b16b2 commit 1b7bf0a
Showing 3 changed files with 301 additions and 0 deletions.
132 changes: 132 additions & 0 deletions shimmer/modules/selection.py
@@ -1,8 +1,10 @@
from abc import ABC, abstractmethod

import torch
import torch.nn as nn

from shimmer.types import LatentsDomainGroupT
from shimmer.utils import group_batch_size, group_device


class SelectionBase(torch.nn.Module, ABC):
@@ -45,3 +47,133 @@ def forward(self, domains: LatentsDomainGroupT) -> dict[str, torch.Tensor]:
            {"v": torch.Tensor([0.0, 0.4, 1.0]), "t": torch.Tensor([1.0, 0.6, 0.0])}
        """
        ...


class KQFixedQSelection(SelectionBase):
    """
    Key-Query attention with a fixed gw vector.
    """

    def __init__(self, domain_dim: int, head_size: int):
        """
        Args:
            domain_dim (`int`): dimension of the input representations
                (assumed to be the same for every domain for now).
            head_size (`int`): dimension of the key and query vectors.
        """
        super().__init__()
        self.head_size = head_size
        self.query_layer = nn.Linear(domain_dim, head_size)
        self.key_layers = nn.ModuleDict(
            {
                "v_latents": nn.Linear(domain_dim, head_size),
                "attr": nn.Linear(domain_dim, head_size),
            }
        )
        self.gw_state: torch.Tensor | None = None

    def update_gw_state(self, gw_state: torch.Tensor) -> None:
        """
        Set the internal copy of the fixed gw state. This is meant to be
        called only once.
        Args:
            gw_state (`torch.Tensor`): the previous GW state
        """
        self.gw_state = gw_state

    def forward(self, domains: LatentsDomainGroupT) -> dict[str, torch.Tensor]:
        """
        Compute keys and queries, match them with dot product and softmax.
        Args:
            domains (`LatentsDomainGroupT`): Group of unimodal latent representations.
        Returns:
            `dict[str, torch.Tensor]`: for each domain in the group, the fusion
            coefficient for each item in the batch.
        """

        if self.gw_state is None:
            raise ValueError("GW state has not been initialized.")

        keys = {
            domain: self.key_layers[domain](encoding)
            for domain, encoding in domains.items()
        }

        device = group_device(domains)
        query = self.query_layer(self.gw_state.to(device))

        dot_products = {
            domain: torch.bmm(key.unsqueeze(1), query.unsqueeze(2)).squeeze()
            for domain, key in keys.items()
        }

        dot_products_tensor = torch.stack(list(dot_products.values()), dim=1)

        attention_scores = torch.softmax(dot_products_tensor, dim=1)

        attention_dict = {
            domain: attention_scores[:, i : i + 1] for i, domain in enumerate(keys)
        }

        return attention_dict


class RandomSelection(SelectionBase):
    """
    Random attention: not learned, with a proportion of binary scaling factors
    and a proportion of uniform scores softmaxed across modalities.
    This class serves to train broadcast with robustness to linear scaling of
    pre-fusion representations.
    """

    def __init__(self, binary_proportion: float, temperature: float):
        """
        Args:
            binary_proportion (`float`): proportion of binary scaling factors
                returned by forward(). Between 0 and 1.
            temperature (`float`): temperature of the softmax applied to the
                uniform scaling factors.
        """
        super().__init__()
        self.binary_proportion = binary_proportion
        self.temperature = temperature

    def forward(self, domains: LatentsDomainGroupT) -> dict[str, torch.Tensor]:
        """
        Randomly draw binary and uniform-then-softmaxed scaling factors
        according to self.binary_proportion.
        Args:
            domains (`LatentsDomainGroupT`): Group of unimodal latent
                representations. Only the batch size and domain names are
                used, not the latent values themselves.
        Returns:
            `dict[str, torch.Tensor]`: for each domain in the group, the fusion
            coefficient for each item in the batch.
        """
        num_domains = len(domains)
        batch_size = group_batch_size(domains)

        # Add extra binary rows when the division is not exact.
        total_binary_scores = int(batch_size * self.binary_proportion)
        num_binary_per_domain, extra_binary_scores = divmod(
            total_binary_scores, num_domains
        )

        # The remaining rows get uniform scores softmaxed across domains.
        num_uniform = batch_size - total_binary_scores

        uniform_scores = torch.rand(num_uniform, num_domains)
        softmax_scores = torch.softmax(uniform_scores / self.temperature, dim=1)

        # Generate binary scores, distributing any extra rows over the first domains.
        binary_scores = []
        for i in range(num_domains):
            binary_score = torch.zeros(
                num_binary_per_domain + (1 if i < extra_binary_scores else 0),
                num_domains,
            )
            binary_score[:, i] = 1
            binary_scores.append(binary_score)
        binary_scores_concat = torch.cat(binary_scores, dim=0)

        all_scores = torch.cat([softmax_scores, binary_scores_concat], dim=0)
        attention_dict = {
            domain: all_scores[:, i : i + 1] for i, domain in enumerate(domains)
        }
        return attention_dict
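
A minimal usage sketch of the two selection modules added above (not part of the commit), assuming the "v_latents" and "attr" domain names hard-coded in KQFixedQSelection and small illustrative sizes; note that update_gw_state must be called before forward:

import torch

from shimmer.modules.selection import KQFixedQSelection, RandomSelection

batch_size, domain_dim, head_size = 32, 12, 6
domains = {
    "v_latents": torch.rand(batch_size, domain_dim),
    "attr": torch.rand(batch_size, domain_dim),
}

# Learned key-query selection: set the fixed GW state before calling forward().
kq_selection = KQFixedQSelection(domain_dim, head_size)
kq_selection.update_gw_state(torch.rand(batch_size, domain_dim))
kq_scores = kq_selection(domains)  # dict of (32, 1) tensors; scores sum to 1 per item

# Random selection: no learned parameters and no GW state needed.
random_selection = RandomSelection(binary_proportion=0.5, temperature=1.0)
random_scores = random_selection(domains)  # same shapes; half the rows are one-hot
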
75 changes: 75 additions & 0 deletions tests/test_kq_onepass_attention.py
@@ -0,0 +1,75 @@
import torch

from shimmer.modules.selection import KQFixedQSelection


def test_single_domain():
    domain_dim = 12
    head_size = 6
    batch_size = 2056

    attention = KQFixedQSelection(domain_dim, head_size)
    gw_state = torch.rand(batch_size, domain_dim)
    attention.update_gw_state(gw_state)

    single_domain_input = {"v_latents": torch.rand(batch_size, domain_dim)}
    attention_scores = attention(single_domain_input)

    expected_scores = torch.ones(batch_size, 1)
    assert torch.allclose(
        attention_scores["v_latents"], expected_scores
    ), "Attention scores for single domain should be all 1s"


def test_multiple_domains_sumis1():
    domain_dim = 12
    head_size = 5
    batch_size = 2056
    attention = KQFixedQSelection(domain_dim, head_size)
    gw_state = torch.rand(batch_size, domain_dim)
    attention.update_gw_state(gw_state)

    multiple_domain_input = {
        "v_latents": torch.rand(batch_size, domain_dim),
        "attr": torch.rand(batch_size, domain_dim),
    }
    attention_scores = attention(multiple_domain_input)

    scores_sum = sum(
        attention_scores[domain].squeeze() for domain in multiple_domain_input.keys()
    )
    expected_sum = torch.ones(batch_size)

    assert torch.allclose(
        scores_sum, expected_sum
    ), "Sum of attention scores across domains should be 1"


def test_attention_backward():
    domain_dim = 12
    head_size = 6
    batch_size = 2056

    attention = KQFixedQSelection(domain_dim, head_size)
    gw_state = torch.rand(batch_size, domain_dim, requires_grad=True)
    attention.update_gw_state(gw_state)

    domains = {
        "v_latents": torch.rand(batch_size, domain_dim, requires_grad=True),
        "attr": torch.rand(batch_size, domain_dim, requires_grad=True),
    }

    attention_scores = attention(domains)
    loss = sum(score.mean() for score in attention_scores.values())
    loss.backward()

    assert gw_state.grad is not None, "Gradients should be computed for gw_state"
    for domain, tensor in domains.items():
        assert (
            tensor.grad is not None
        ), f"Gradients should be computed for domain '{domain}' inputs"

    for name, param in attention.named_parameters():
        assert (
            param.grad is not None
        ), f"Gradients should be computed for parameter '{name}'"
94 changes: 94 additions & 0 deletions tests/test_random_attention.py
@@ -0,0 +1,94 @@
import numpy as np
import torch

from shimmer.modules.selection import RandomSelection


def test_multiple_domains():
    binary_proportion = 0.5
    temperature = 1.0
    domain_dim = 12
    batch_size = 2056

    selection = RandomSelection(binary_proportion, temperature)
    multiple_domain_input = {
        "v_latents": torch.rand(batch_size, domain_dim),
        "attr": torch.rand(batch_size, domain_dim),
    }
    selection_scores = selection(multiple_domain_input)

    # Ensure the sum of attention scores across domains equals 1
    scores_sum = sum(
        selection_scores[domain].squeeze() for domain in multiple_domain_input.keys()
    )
    expected_sum = torch.ones(batch_size)

    assert torch.allclose(
        scores_sum, expected_sum
    ), "Sum of selection scores across domains should be 1"


def test_three_domains():
    binary_proportion = 0.5
    temperature = 1.0
    domain_dim = 12
    batch_size = 2056

    selection = RandomSelection(binary_proportion, temperature)
    three_domain_input = {
        "v_latents": torch.rand(batch_size, domain_dim),
        "attr": torch.rand(batch_size, domain_dim),
        "audio": torch.rand(batch_size, domain_dim),
    }
    selection_scores = selection(three_domain_input)

    # Ensure that the shape of the selection scores matches the input domains
    for domain in three_domain_input.keys():
        assert selection_scores[domain].shape == (
            batch_size,
            1,
        ), f"Scores shape mismatch for {domain}"

    # The binary/uniform split itself is checked in
    # test_binary_scores_xor_check_for_multiple_proportions below.

    # Check if the sum of selection scores across domains equals 1
    scores_sum = sum(
        selection_scores[domain].squeeze() for domain in three_domain_input.keys()
    )
    expected_sum = torch.ones(batch_size)

    assert torch.allclose(
        scores_sum, expected_sum
    ), "Sum of selection scores across three domains should be 1"


def test_binary_scores_xor_check_for_multiple_proportions():
    temperature = 1.0
    domain_dim = 12
    batch_size = 2056
    num_tests = 10  # Number of random proportions to test

    for _ in range(num_tests):
        binary_proportion = np.random.rand()  # Random proportion between 0 and 1

        selection = RandomSelection(binary_proportion, temperature)
        domains_input = {
            "v_latents": torch.rand(batch_size, domain_dim),
            "attr": torch.rand(batch_size, domain_dim),
            "audio": torch.rand(batch_size, domain_dim),
        }
        selection_scores = selection(domains_input)

        scores_matrix = torch.cat(
            [selection_scores[domain] for domain in domains_input.keys()], dim=1
        )
        binary_scores_mask = scores_matrix == 1
        xor_binary_check = binary_scores_mask.sum(dim=1) == 1
        num_binary_rows = xor_binary_check.sum().item()
        expected_num_binary_rows = int(batch_size * binary_proportion)

        assert (
            num_binary_rows == expected_num_binary_rows
        ), f"Incorrect number of binary score rows for proportion {binary_proportion:.2f}: expected {expected_num_binary_rows}, got {num_binary_rows}"
