Skip to content

Commit

Permalink
update docstring for selection-related methods
Browse files Browse the repository at this point in the history
  • Loading branch information
bdvllrs committed Oct 9, 2024
1 parent b6eba5d commit a0371bb
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 6 deletions.
3 changes: 2 additions & 1 deletion shimmer/modules/global_workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,8 @@ def encode_and_fuse(
Args:
x (`LatentsDomainGroupsT`): the input domain representations.
selection_scores (`Mapping[str, torch.Tensor]`):
selection_module (`SelectionBase`): selection module to use to obtain
selection scores.
Returns:
`dict[frozenset[str], torch.Tensor]`: the GW representations.
Expand Down
4 changes: 2 additions & 2 deletions shimmer/modules/gw_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,8 +298,8 @@ def encode_and_fuse(
Args:
x (`LatentsDomainGroupT`): the input domain representations
selection_score (`Mapping[str, torch.Tensor]`): attention scores to
use to encode the reprensetation.
selection_module (`SelectionBase`): selection module to use to obtain
selection scores.
Returns:
`torch.Tensor`: The merged representation.
Expand Down
11 changes: 8 additions & 3 deletions shimmer/modules/selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ def forward(
Args:
domains (`LatentsDomainGroupT`): Group of unimodal latent representations.
encodings_pre_fusion (`LatentsDomainGroupT`): pre-fusion domain latent
representation.
Returns:
`dict[str, torch.Tensor]`: for each domain in the group, the fusion
Expand Down Expand Up @@ -75,7 +77,8 @@ def forward(
Args:
domains (`LatentsDomainGroupT`): input unimodal latent representations
gw_state (`torch.Tensor`): the previous GW state
encodings_pre_fusion (`LatentsDomainGroupT`): pre-fusion domain latent
representation.
Returns:
`dict[str, torch.Tensor]`: whether the domain is selected for each input
Expand Down Expand Up @@ -105,7 +108,8 @@ def forward(
Args:
domains (`LatentsDomainGroupT`): input unimodal latent representations
gw_state (`torch.Tensor`): the previous GW state
encodings_pre_fusion (`LatentsDomainGroupT`): pre-fusion domain latent
representation.
Returns:
`dict[str, torch.Tensor]`: whether the domain is selected for each input
Expand Down Expand Up @@ -281,7 +285,8 @@ def forward(
Args:
domains (`LatentsDomainGroupT`): Group of unimodal latent representations.
encodings (`LatentsDomainGroupT`): Group of pre-fusion encodings.
encodings_pre_fusion (`LatentsDomainGroupT`): pre-fusion domain latent
representation.
Returns:
`dict[str, torch.Tensor]`: the attention scores for each domain in the
Expand Down

0 comments on commit a0371bb

Please sign in to comment.