fixed the hardcoding in onepassattention (#53)
RolandBERTINJOHANNET authored Apr 11, 2024
1 parent ee16d5c commit 877291e
Showing 2 changed files with 7 additions and 42 deletions.
8 changes: 3 additions & 5 deletions shimmer/modules/selection.py
@@ -94,21 +94,19 @@ class KQFixedQSelection(SelectionBase):
     Key-Query attention with a fixed gw vector.
     """
 
-    def __init__(self, domain_dim: int, head_size: int):
+    def __init__(self, domain_dim: int, head_size: int, domains: Iterable[str]):
         """
         Args:
             domain_dim (`int`) : dimension of the input dims
                 (assumed to be the same for now)
             head_size (`int`) : dimension of the key and query vectors.
+            domains (`Iterable[str]`) : list of input domains
         """
         super().__init__()
         self.head_size = head_size
         self.query_layer = nn.Linear(domain_dim, head_size)
         self.key_layers = nn.ModuleDict(
-            {
-                "v_latents": nn.Linear(domain_dim, head_size),
-                "attr": nn.Linear(domain_dim, head_size),
-            }
+            {domain: nn.Linear(domain_dim, head_size) for domain in domains}
         )
         self.gw_state: torch.Tensor | None = None
 
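With the hardcoded "v_latents"/"attr" keys removed, the constructor builds one key projection per name passed in `domains`. A minimal usage sketch under that assumption (the extra "audio" domain, the tensor shapes, and the import path inferred from the file above are illustrative, not part of this commit):

import torch

from shimmer.modules.selection import KQFixedQSelection

domain_names = ["v_latents", "attr", "audio"]  # "audio" is a hypothetical third domain
domain_dim, head_size, batch_size = 12, 6, 4

selection = KQFixedQSelection(domain_dim, head_size, domain_names)
selection.update_gw_state(torch.rand(batch_size, domain_dim))  # fixed gw vector used for the query

unimodal = {name: torch.rand(batch_size, domain_dim) for name in domain_names}
prefusion = {name: torch.rand(batch_size, domain_dim) for name in domain_names}
scores = selection(unimodal, prefusion)  # one score tensor per domain; the tests below check they sum to 1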
41 changes: 4 additions & 37 deletions tests/test_kq_onepass_attention.py
@@ -7,8 +7,9 @@ def test_single_domain():
     domain_dim = 12
     head_size = 6
     batch_size = 2056
+    domains = ["v_latents"]
 
-    attention = KQFixedQSelection(domain_dim, head_size)
+    attention = KQFixedQSelection(domain_dim, head_size, domains)
     gw_state = torch.rand(batch_size, domain_dim)
     attention.update_gw_state(gw_state)
 
@@ -26,7 +27,8 @@ def test_multiple_domains_sumis1():
     domain_dim = 12
     head_size = 5
     batch_size = 2056
-    attention = KQFixedQSelection(domain_dim, head_size)
+    domains = ["v_latents", "attr"]
+    attention = KQFixedQSelection(domain_dim, head_size, domains)
     gw_state = torch.rand(batch_size, domain_dim)
     attention.update_gw_state(gw_state)
 
@@ -51,38 +53,3 @@ def test_multiple_domains_sumis1():
     assert torch.allclose(
         scores_sum, expected_sum
     ), "Sum of attention scores across domains should be 1"
-
-
-def test_attention_backward():
-    domain_dim = 12
-    head_size = 6
-    batch_size = 2056
-
-    attention = KQFixedQSelection(domain_dim, head_size)
-    gw_state = torch.rand(batch_size, domain_dim, requires_grad=True)
-    attention.update_gw_state(gw_state)
-
-    domains = {
-        "v_latents": torch.rand(batch_size, domain_dim, requires_grad=True),
-        "attr": torch.rand(batch_size, domain_dim, requires_grad=True),
-    }
-
-    encodings_pre_fusion = {
-        "v_latents": torch.rand(batch_size, domain_dim, requires_grad=True),
-        "attr": torch.rand(batch_size, domain_dim, requires_grad=True),
-    }
-    attention_scores = attention(domains, encodings_pre_fusion)
-    loss = sum(score.mean() for score in attention_scores.values())
-    assert isinstance(loss, torch.Tensor)
-    loss.backward()
-
-    assert gw_state.grad is not None, "Gradients should be computed for gw_state"
-    for domain, tensor in domains.items():
-        assert (
-            tensor.grad is not None
-        ), f"Gradients should be computed for domain '{domain}' inputs"
-
-    for name, param in attention.named_parameters():
-        assert (
-            param.grad is not None
-        ), f"Gradients should be computed for parameter '{name}'"
