* first draft attention mechanism
* first draft (not running)
* ruff fixes
* fixed errors in classification_loss
* got rid of attention-specific globalworkspace
* ruff fixes (again)
* got rid of task-specific gwlosses
* first draft for attention (far from finished)
* will need to resolve conflicts with upcoming selectionbase branch
* get rid of checkpoint files
* global workspace no longer has the attention mechanism
* rebase (for self.gw_states instead of forward arg)
* rebase (to get the self.gw_state instead of arg to forward)
* correct KQAttention module (todo docstring)
* added test for KQattention
* ruff checks
* trying ruff format again..
* forgot to ruff the tests
* few mypy fixes (still wont work)
* took out BinaryAttention -no longer needed
* fixed rebase's code duplications
* fixing rebase issues
* conform to base class with changed gw_state
* docstring + rewrote tests for new class
* random selection (todo tests & docstring)
* finished randomAttention minus docstring
* ruff format on attention test file
* added docstring for randomselection
* took out useless call in tests/test_random...
* requested changes
* adapted kq tests to new class name
* requested fixes

---------

Co-authored-by: Lara <[email protected]>
1 parent d4b16b2 · commit 1b7bf0a · 3 changed files with 301 additions and 0 deletions.
@@ -0,0 +1,75 @@
import torch

from shimmer.modules.selection import KQFixedQSelection


def test_single_domain():
    domain_dim = 12
    head_size = 6
    batch_size = 2056

    attention = KQFixedQSelection(domain_dim, head_size)
    gw_state = torch.rand(batch_size, domain_dim)
    attention.update_gw_state(gw_state)

    single_domain_input = {"v_latents": torch.rand(batch_size, domain_dim)}
    attention_scores = attention(single_domain_input)

    expected_scores = torch.ones(batch_size, 1)
    assert torch.allclose(
        attention_scores["v_latents"], expected_scores
    ), "Attention scores for single domain should be all 1s"


def test_multiple_domains_sumis1():
    domain_dim = 12
    head_size = 5
    batch_size = 2056
    attention = KQFixedQSelection(domain_dim, head_size)
    gw_state = torch.rand(batch_size, domain_dim)
    attention.update_gw_state(gw_state)

    multiple_domain_input = {
        "v_latents": torch.rand(batch_size, domain_dim),
        "attr": torch.rand(batch_size, domain_dim),
    }
    attention_scores = attention(multiple_domain_input)

    scores_sum = sum(
        attention_scores[domain].squeeze() for domain in multiple_domain_input.keys()
    )
    expected_sum = torch.ones(batch_size)

    assert torch.allclose(
        scores_sum, expected_sum
    ), "Sum of attention scores across domains should be 1"


def test_attention_backward():
    domain_dim = 12
    head_size = 6
    batch_size = 2056

    attention = KQFixedQSelection(domain_dim, head_size)
    gw_state = torch.rand(batch_size, domain_dim, requires_grad=True)
    attention.update_gw_state(gw_state)

    domains = {
        "v_latents": torch.rand(batch_size, domain_dim, requires_grad=True),
        "attr": torch.rand(batch_size, domain_dim, requires_grad=True),
    }

    attention_scores = attention(domains)
    loss = sum(score.mean() for score in attention_scores.values())
    loss.backward()

    assert gw_state.grad is not None, "Gradients should be computed for gw_state"
    for domain, tensor in domains.items():
        assert (
            tensor.grad is not None
        ), f"Gradients should be computed for domain '{domain}' inputs"

    for name, param in attention.named_parameters():
        assert (
            param.grad is not None
        ), f"Gradients should be computed for parameter '{name}'"
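For orientation, the tests above pin down the interface of KQFixedQSelection: it is constructed with (domain_dim, head_size), receives the global-workspace state through update_gw_state, and, when called on a dict of per-domain encodings, returns one score tensor of shape (batch_size, 1) per domain, with scores summing to 1 across domains and gradients flowing to the state, the inputs, and every parameter. The sketch below is a minimal, hypothetical module satisfying exactly those checks; the class and layer names (KQSelectionSketch, query_layer, key_layer) and the shared key projection are illustrative assumptions, not shimmer's actual implementation.

import torch
import torch.nn as nn


class KQSelectionSketch(nn.Module):
    """Hypothetical key-query selection: keys from each domain encoding,
    a query from the global-workspace state, softmax over domains."""

    def __init__(self, domain_dim: int, head_size: int):
        super().__init__()
        self.query_layer = nn.Linear(domain_dim, head_size)
        self.key_layer = nn.Linear(domain_dim, head_size)
        self.gw_state = None

    def update_gw_state(self, gw_state: torch.Tensor) -> None:
        # Store the global-workspace state used to build the query.
        self.gw_state = gw_state

    def forward(self, domains: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        assert self.gw_state is not None, "call update_gw_state before forward"
        query = self.query_layer(self.gw_state)  # (batch, head_size)
        # One dot-product logit per domain, then softmax across domains,
        # so the scores sum to 1 (and equal 1 when only one domain is given).
        logits = torch.stack(
            [(query * self.key_layer(enc)).sum(dim=1) for enc in domains.values()],
            dim=1,
        )
        weights = torch.softmax(logits, dim=1)
        return {name: weights[:, i : i + 1] for i, name in enumerate(domains)}

Substituting this sketch for KQFixedQSelection would make the tests above pass, which is the sense in which it captures the interface rather than the internals.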
@@ -0,0 +1,94 @@
import numpy as np
import torch

from shimmer.modules.selection import RandomSelection


def test_multiple_domains():
    binary_proportion = 0.5
    temperature = 1.0
    domain_dim = 12
    batch_size = 2056

    selection = RandomSelection(binary_proportion, temperature)
    multiple_domain_input = {
        "v_latents": torch.rand(batch_size, domain_dim),
        "attr": torch.rand(batch_size, domain_dim),
    }
    selection_scores = selection(multiple_domain_input)

    # Ensure the sum of attention scores across domains equals 1
    scores_sum = sum(
        selection_scores[domain].squeeze() for domain in multiple_domain_input.keys()
    )
    expected_sum = torch.ones(batch_size)

    assert torch.allclose(
        scores_sum, expected_sum
    ), "Sum of selection scores across domains should be 1"


def test_three_domains():
    binary_proportion = 0.5
    temperature = 1.0
    domain_dim = 12
    batch_size = 2056

    selection = RandomSelection(binary_proportion, temperature)
    three_domain_input = {
        "v_latents": torch.rand(batch_size, domain_dim),
        "attr": torch.rand(batch_size, domain_dim),
        "audio": torch.rand(batch_size, domain_dim),
    }
    selection_scores = selection(three_domain_input)

    # Ensure that the shape of the selection scores matches the input domains
    for domain in three_domain_input.keys():
        assert selection_scores[domain].shape == (
            batch_size,
            1,
        ), f"Scores shape mismatch for {domain}"

    # Check if the binary scores are as expected
    # This part might need adjustments based on how binary scores are distributed
    # and combined with uniform scores in your actual implementation

    # Check if the sum of selection scores across domains equals 1
    scores_sum = sum(
        selection_scores[domain].squeeze() for domain in three_domain_input.keys()
    )
    expected_sum = torch.ones(batch_size)

    assert torch.allclose(
        scores_sum, expected_sum
    ), "Sum of selection scores across three domains should be 1"


def test_binary_scores_xor_check_for_multiple_proportions():
    temperature = 1.0
    domain_dim = 12
    batch_size = 2056
    num_tests = 10  # Number of random proportions to test

    for _ in range(num_tests):
        binary_proportion = np.random.rand()  # Random proportion between 0 and 1

        selection = RandomSelection(binary_proportion, temperature)
        domains_input = {
            "v_latents": torch.rand(batch_size, domain_dim),
            "attr": torch.rand(batch_size, domain_dim),
            "audio": torch.rand(batch_size, domain_dim),
        }
        selection_scores = selection(domains_input)

        scores_matrix = torch.cat(
            [selection_scores[domain] for domain in domains_input.keys()], dim=1
        )
        binary_scores_mask = scores_matrix == 1
        xor_binary_check = binary_scores_mask.sum(dim=1) == 1
        num_binary_rows = xor_binary_check.sum().item()
        expected_num_binary_rows = int(batch_size * binary_proportion)

        assert (
            num_binary_rows == expected_num_binary_rows
        ), f"Incorrect number of binary score rows for proportion {binary_proportion:.2f}: expected {expected_num_binary_rows}, got {num_binary_rows}"
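The last test is the most informative about RandomSelection's contract: for any binary_proportion, exactly int(batch_size * binary_proportion) rows must receive a one-hot (binary) score across domains, and every row's scores must still sum to 1. Below is a minimal, hypothetical sketch consistent with these tests; the class name RandomSelectionSketch, the softmax over random logits for the non-binary rows, and placing the binary rows at the start of the batch are illustrative assumptions rather than shimmer's actual implementation.

import torch
import torch.nn as nn
import torch.nn.functional as F


class RandomSelectionSketch(nn.Module):
    """Hypothetical random selection: a `binary_proportion` fraction of rows
    gets a one-hot score over domains; the rest get softmax of random logits."""

    def __init__(self, binary_proportion: float, temperature: float):
        super().__init__()
        self.binary_proportion = binary_proportion
        self.temperature = temperature

    def forward(self, domains: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        names = list(domains)
        batch_size = domains[names[0]].shape[0]
        n_domains = len(names)

        # Soft scores: softmax over random logits, so each row sums to 1.
        logits = torch.rand(batch_size, n_domains) / self.temperature
        scores = torch.softmax(logits, dim=1)

        # Overwrite exactly int(batch_size * binary_proportion) rows with
        # one-hot scores, matching the xor check in the test above.
        n_binary = int(batch_size * self.binary_proportion)
        if n_binary > 0:
            choice = torch.randint(n_domains, (n_binary,))
            scores[:n_binary] = F.one_hot(choice, n_domains).float()

        return {name: scores[:, i : i + 1] for i, name in enumerate(names)}

In this sketch the temperature only rescales the random logits; the tests above constrain only the count of binary rows, the per-row normalization, and the (batch_size, 1) shape per domain, so other internals are left open.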