Fix query attention test file (#47)
* fixed mypy errors

* fixed mypy
larascipio authored Apr 4, 2024
1 parent c2e69a8 commit 8704f33
Showing 2 changed files with 6 additions and 8 deletions.
shimmer/modules/selection.py: 2 changes (1 addition, 1 deletion)
@@ -274,7 +274,7 @@ def calculate_attention_dict(

     def fuse_weighted_encodings(
         self, encodings: LatentsDomainGroupT, attention_dict: dict[str, torch.Tensor]
-    ) -> dict[str, torch.Tensor]:
+    ) -> torch.Tensor:
         # Apply attention scores to the encodings
         weighted_encodings = {}
         for key in attention_dict:
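
The annotation change reflects what the method actually returns: fuse_weighted_encodings collapses the attention-weighted per-domain encodings into a single fused tensor, so dict[str, torch.Tensor] was wrong. A minimal sketch of the pattern, assuming (an assumption for illustration, not necessarily shimmer's actual fusion rule) that fusion sums the weighted encodings:

import torch

def fuse_weighted_encodings(
    encodings: dict[str, torch.Tensor],
    attention_dict: dict[str, torch.Tensor],
) -> torch.Tensor:
    # Weight each domain's encoding by its attention score.
    weighted_encodings = {}
    for key in attention_dict:
        weighted_encodings[key] = attention_dict[key] * encodings[key]
    # Fuse across domains into one tensor, which is why the return
    # type is torch.Tensor rather than a per-domain dict.
    return torch.sum(torch.stack(list(weighted_encodings.values())), dim=0)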
tests/test_query_attention.py: 12 changes (5 additions, 7 deletions)
@@ -42,6 +42,8 @@ def test_multiple_domains_sumis1():
     scores_sum = sum(
         attention_scores[domain].squeeze() for domain in multiple_domain_input
     )
+    assert isinstance(scores_sum, torch.Tensor)
+
     expected_sum = torch.ones(batch_size)
 
     assert torch.allclose(
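
The added assert is the standard workaround for how mypy types builtins.sum: with the implicit start value 0, summing an iterable of tensors is inferred as int | Tensor, and passing scores_sum straight to torch.allclose fails type checking on the int half of the union. assert isinstance narrows the type at the point of use. A self-contained sketch of the same pattern (illustrative, not the shimmer test itself):

import torch

scores = {"v_latents": torch.full((4,), 0.5), "attr": torch.full((4,), 0.5)}

# mypy infers int | Tensor here because sum() starts from 0.
scores_sum = sum(s.squeeze() for s in scores.values())
assert isinstance(scores_sum, torch.Tensor)  # narrows the union to Tensor

assert torch.allclose(scores_sum, torch.ones(4))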
@@ -53,9 +55,9 @@ def test_attention_backward():
     domain_dim = 12
     head_size = 6
     batch_size = 2056
-    domains = ["v_latents", "attr"]
+    domain_names = ["v_latents", "attr"]
 
-    attention = DynamicQueryAttention(batch_size, domain_dim, head_size, domains)
+    attention = DynamicQueryAttention(batch_size, domain_dim, head_size, domain_names)
 
     domains = {
         "v_latents": torch.rand(batch_size, domain_dim, requires_grad=True),
@@ -69,9 +71,5 @@ def test_attention_backward():
     attention_scores = attention(domains, prefusion_encodings)
 
     loss = sum(score.mean() for score in attention_scores.values())
+    assert isinstance(loss, torch.Tensor)
     loss.backward()
-
-    for name, param in attention.named_parameters():
-        assert (
-            param.grad is not None
-        ), f"Gradients should be computed for parameter '{name}'"
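
The same narrowing guards loss before backward(): without the assert, mypy flags .backward() on a value typed int | Tensor. An alternative that keeps the type as Tensor throughout and avoids the runtime assert is stacking the per-domain means (a design option sketched here, not what this commit does):

import torch

attention_scores = {
    "v_latents": torch.rand(8, 1, requires_grad=True),
    "attr": torch.rand(8, 1, requires_grad=True),
}

# torch.stack returns a Tensor, so no isinstance narrowing is needed
# before calling backward().
loss = torch.stack([s.mean() for s in attention_scores.values()]).sum()
loss.backward()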
