Skip to content

Commit

Permalink
Add selection coefs to fusion (#42)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdvllrs authored Apr 3, 2024
1 parent 0913906 commit 3ae4c51
Showing 1 changed file with 19 additions and 25 deletions.
44 changes: 19 additions & 25 deletions shimmer/modules/gw_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,28 +298,24 @@ def __init__(
def fuse(
self,
x: LatentsDomainGroupT,
selection_scores: Mapping[str, torch.Tensor] | None = None,
selection_scores: Mapping[str, torch.Tensor],
) -> torch.Tensor:
"""
Merge function used to combine domains.
Args:
x (`LatentsDomainGroupT`): the group of latent representation.
selection_score (`Mapping[str, torch.Tensor] | None`): attention scores to
selection_score (`Mapping[str, torch.Tensor]`): attention scores to
use to encode the reprensetation.
Returns:
`torch.Tensor`: The merged representation.
"""
return torch.sum(torch.stack(list(x.values())), dim=0)

def encode_and_fuse(
self,
x: LatentsDomainGroupT,
selection_module: SelectionBase | None = None,
) -> torch.Tensor:
encodings = self.encode(x)
selection_scores = selection_module(x, encodings)
return self.fuse(encodings, selection_scores)
return torch.sum(
torch.stack(
[selection_scores[domain] * x[domain] for domain in selection_scores]
),
dim=0,
)

def encode(self, x: LatentsDomainGroupT) -> LatentsDomainGroupT:
"""
Expand Down Expand Up @@ -389,19 +385,24 @@ def __init__(
def fuse(
self,
x: LatentsDomainGroupT,
selection_scores: Mapping[str, torch.Tensor] | None = None,
selection_scores: Mapping[str, torch.Tensor],
) -> torch.Tensor:
"""
Fusion of the pre-fusion GW representations.
Merge function used to combine domains.
Args:
x (`LatentsDomainGroupT`): pre-fusion GW representations.
selection_score (`Mapping[str, torch.Tensor] | None`): attention scores to
x (`LatentsDomainGroupT`): the group of latent representation.
selection_score (`Mapping[str, torch.Tensor]`): attention scores to
use to encode the reprensetation.
Returns:
`torch.Tensor`: the merged GW representation.
`torch.Tensor`: The merged representation.
"""
return torch.sum(torch.stack(list(x.values())), dim=0)
return torch.sum(
torch.stack(
[selection_scores[domain] * x[domain] for domain in selection_scores]
),
dim=0,
)

def encode(self, x: LatentsDomainGroupT) -> LatentsDomainGroupT:
"""
Expand All @@ -418,13 +419,6 @@ def encode(self, x: LatentsDomainGroupT) -> LatentsDomainGroupT:
for domain_name, domain in x.items()
}

def encode_and_fuse(
self, x: LatentsDomainGroupT, selection_module=SelectionBase
) -> torch.Tensor:
encodings = self.encode(x)
selection_scores = selection_module(x, encodings)
return self.fuse(encodings, selection_scores)

def encoded_distribution(
self, x: LatentsDomainGroupT
) -> tuple[LatentsDomainGroupDT, LatentsDomainGroupDT]:
Expand Down

0 comments on commit 3ae4c51

Please sign in to comment.