From 970418f5080c213a26b86aebffc01ead265f161d Mon Sep 17 00:00:00 2001 From: Benjamin Devillers Date: Wed, 27 Mar 2024 15:39:49 +0100 Subject: [PATCH] Update the Selection module's internal representation of the GW state with an external method (#35) --- shimmer/modules/selection.py | 41 +++++++++++++++++++++++++++++++++--- 1 file changed, 38 insertions(+), 3 deletions(-) diff --git a/shimmer/modules/selection.py b/shimmer/modules/selection.py index c48127ee..c4b6c073 100644 --- a/shimmer/modules/selection.py +++ b/shimmer/modules/selection.py @@ -6,7 +6,42 @@ class SelectionBase(torch.nn.Module, ABC): + """ + This is the base class for the selection mechanism. + The selection mechanisms handles the "competition" between modules and *selects* + fusion coefficients for the domains. + """ + + def update_gw_state(self, gw_state: torch.Tensor) -> None: + """ + Update the internal copy of the previous GW state. + By default, this is not implemented and will raise an error if used. + + :note.. + This is not defined as an abstractmethod as some selection method may + not need it. + + Args: + gw_state (`torch.Tensor`): the previous GW state + """ + pass + @abstractmethod - def forward( - self, domains: LatentsDomainGroupT, gw_state: torch.Tensor - ) -> dict[str, torch.Tensor]: ... + def forward(self, domains: LatentsDomainGroupT) -> dict[str, torch.Tensor]: + """ + Forward pass of the selection method. + + Args: + domains (`LatentsDomainGroupT`): Group of unimodal latent representations. + + Returns: + `dict[str, torch.Tensor]`: for each domain in the group, the fusion + coefficient for each item in the batch. + + Example: + >>> SomeSelectionImplementation().forward( + ... {"v": torch.randn(3, 4), "t": torch.randn(3, 8)} + ... ) + {"v": torch.Tensor([0.0, 0.4, 1.0]), "t": torch.Tensor([1.0, 0.6, 0.0])} + """ + ...