Skip to content

Commit

Permalink
Use sum instead of mean in fusion
Browse files Browse the repository at this point in the history
  • Loading branch information
bdvllrs committed Feb 26, 2024
1 parent 4a5ea2f commit 4e6d65b
Showing 1 changed file with 2 additions and 21 deletions.
23 changes: 2 additions & 21 deletions shimmer/modules/gw_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ def cycle(
}


class GWModuleFusion(GWModuleBase):
class GWModuleFusion(GWModule):
def fusion_mechanism(self, x: Mapping[str, torch.Tensor]) -> torch.Tensor:
"""
Merge function used to combine domains.
Expand All @@ -393,7 +393,7 @@ def fusion_mechanism(self, x: Mapping[str, torch.Tensor]) -> torch.Tensor:
Returns:
The merged representation
"""
return torch.mean(torch.stack(list(x.values())), dim=0)
return torch.sum(torch.stack(list(x.values())), dim=0)

def get_batch_size(self, x: Mapping[str, torch.Tensor]) -> int:
for val in x.values():
Expand Down Expand Up @@ -422,22 +422,3 @@ def encode(self, x: Mapping[str, torch.Tensor]) -> torch.Tensor:
for domain in x.keys()
}
)

def decode(
self, z: torch.Tensor, domains: Iterable[str] | None = None
) -> dict[str, torch.Tensor]:
return {
domain: self.gw_interfaces[domain].decode(z)
for domain in domains or self.gw_interfaces.keys()
}

def translate(self, x: Mapping[str, torch.Tensor], to: str) -> torch.Tensor:
return self.decode(self.encode(x), domains={to})[to]

def cycle(
self, x: Mapping[str, torch.Tensor], through: str
) -> dict[str, torch.Tensor]:
return {
domain: self.translate({through: self.translate(x, through)}, domain)
for domain in x.keys()
}

0 comments on commit 4e6d65b

Please sign in to comment.