Skip to content

Commit

Permalink
Update the Selection module's internal representation of the GW state…
Browse files Browse the repository at this point in the history
… with an external method (#35)
  • Loading branch information
bdvllrs authored Mar 27, 2024
1 parent 9cd111d commit 970418f
Showing 1 changed file with 38 additions and 3 deletions.
41 changes: 38 additions & 3 deletions shimmer/modules/selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,42 @@


class SelectionBase(torch.nn.Module, ABC):
"""
This is the base class for the selection mechanism.
The selection mechanisms handles the "competition" between modules and *selects*
fusion coefficients for the domains.
"""

def update_gw_state(self, gw_state: torch.Tensor) -> None:
"""
Update the internal copy of the previous GW state.
By default, this is not implemented and will raise an error if used.
:note..
This is not defined as an abstractmethod as some selection method may
not need it.
Args:
gw_state (`torch.Tensor`): the previous GW state
"""
pass

@abstractmethod
def forward(
self, domains: LatentsDomainGroupT, gw_state: torch.Tensor
) -> dict[str, torch.Tensor]: ...
def forward(self, domains: LatentsDomainGroupT) -> dict[str, torch.Tensor]:
"""
Forward pass of the selection method.
Args:
domains (`LatentsDomainGroupT`): Group of unimodal latent representations.
Returns:
`dict[str, torch.Tensor]`: for each domain in the group, the fusion
coefficient for each item in the batch.
Example:
>>> SomeSelectionImplementation().forward(
... {"v": torch.randn(3, 4), "t": torch.randn(3, 8)}
... )
{"v": torch.Tensor([0.0, 0.4, 1.0]), "t": torch.Tensor([1.0, 0.6, 0.0])}
"""
...

0 comments on commit 970418f

Please sign in to comment.