add step_limit to DynamicQueryAttention
RolandBERTINJOHANNET committed Sep 26, 2024
1 parent ca2ab8c commit c3cad6e
Showing 1 changed file with 13 additions and 1 deletion.
14 changes: 13 additions & 1 deletion shimmer/modules/selection.py
@@ -225,9 +225,21 @@ def __init__(
             {domain: nn.Linear(domain_dim, head_size) for domain in domain_names}
         )
         self.n_steps = n_steps
+        self.step_limit = n_steps  # Default step limit is n_steps
         # Start with a random gw state
         self.register_buffer("initial_gw_state", torch.rand(domain_dim))
 
+    def set_step_limit(self, step_limit: int):
+        """
+        Sets the step limit for the dynamic attention update loop.
+
+        Args:
+            step_limit (`int`): Maximum number of steps to run the loop.
+        """
+        if step_limit > self.n_steps:
+            raise ValueError(f"Step limit cannot exceed the maximum n_steps ({self.n_steps}).")
+        self.step_limit = step_limit
+
     def fuse_weighted_encodings(
         self, encodings: LatentsDomainGroupT, attention_dict: dict[str, torch.Tensor]
     ) -> torch.Tensor:
@@ -289,7 +301,7 @@ def forward(
 
         if self.n_steps > 0:
             # Update the query based on the static attention scores
-            for _ in range(self.n_steps):
+            for _ in range(min(self.step_limit, self.n_steps)):
                 # Apply the attention scores to the encodings
                 summed_tensor = self.fuse_weighted_encodings(
                     encodings_pre_fusion, attention_dict
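For reference, a minimal usage sketch of the new method (not part of the commit). The constructor keyword arguments mirror the parameter names visible in the diff; the exact signature, dimensions, and domain names below are assumptions for illustration:

    from shimmer.modules.selection import DynamicQueryAttention

    # Hypothetical setup: values are placeholders, and the constructor
    # signature is inferred from the parameters seen in the diff above.
    attention = DynamicQueryAttention(
        domain_dim=12,
        head_size=6,
        domain_names={"v_latents", "attr"},
        n_steps=3,
    )

    # Cap the dynamic update loop at 2 steps: forward() now iterates
    # min(step_limit, n_steps) = 2 times instead of 3.
    attention.set_step_limit(2)

    # Requesting more steps than n_steps is rejected by the new guard:
    # attention.set_step_limit(5)  # raises ValueError

Note that the min(self.step_limit, self.n_steps) in forward is defensive: set_step_limit already prevents step_limit from exceeding n_steps.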
