From 4e6d65b3a3a4737c187053e3713abc3020c850be Mon Sep 17 00:00:00 2001 From: bdvllrs Date: Mon, 26 Feb 2024 15:06:51 +0000 Subject: [PATCH] Use sum instead of mean in fusion --- shimmer/modules/gw_module.py | 23 ++--------------------- 1 file changed, 2 insertions(+), 21 deletions(-) diff --git a/shimmer/modules/gw_module.py b/shimmer/modules/gw_module.py index 55cf742d..d3651e76 100644 --- a/shimmer/modules/gw_module.py +++ b/shimmer/modules/gw_module.py @@ -384,7 +384,7 @@ def cycle( } -class GWModuleFusion(GWModuleBase): +class GWModuleFusion(GWModule): def fusion_mechanism(self, x: Mapping[str, torch.Tensor]) -> torch.Tensor: """ Merge function used to combine domains. @@ -393,7 +393,7 @@ def fusion_mechanism(self, x: Mapping[str, torch.Tensor]) -> torch.Tensor: Returns: The merged representation """ - return torch.mean(torch.stack(list(x.values())), dim=0) + return torch.sum(torch.stack(list(x.values())), dim=0) def get_batch_size(self, x: Mapping[str, torch.Tensor]) -> int: for val in x.values(): @@ -422,22 +422,3 @@ def encode(self, x: Mapping[str, torch.Tensor]) -> torch.Tensor: for domain in x.keys() } ) - - def decode( - self, z: torch.Tensor, domains: Iterable[str] | None = None - ) -> dict[str, torch.Tensor]: - return { - domain: self.gw_interfaces[domain].decode(z) - for domain in domains or self.gw_interfaces.keys() - } - - def translate(self, x: Mapping[str, torch.Tensor], to: str) -> torch.Tensor: - return self.decode(self.encode(x), domains={to})[to] - - def cycle( - self, x: Mapping[str, torch.Tensor], through: str - ) -> dict[str, torch.Tensor]: - return { - domain: self.translate({through: self.translate(x, through)}, domain) - for domain in x.keys() - }