diff --git a/shimmer/modules/gw_module.py b/shimmer/modules/gw_module.py index 5e0b25bd..55cf742d 100644 --- a/shimmer/modules/gw_module.py +++ b/shimmer/modules/gw_module.py @@ -382,3 +382,62 @@ def cycle( domain: self.translate({through: self.translate(x, through)}, domain) for domain in x.keys() } + + +class GWModuleFusion(GWModuleBase): + def fusion_mechanism(self, x: Mapping[str, torch.Tensor]) -> torch.Tensor: + """ + Merge function used to combine domains. + Args: + x: mapping of domain name to latent representation. + Returns: + The merged representation + """ + return torch.mean(torch.stack(list(x.values())), dim=0) + + def get_batch_size(self, x: Mapping[str, torch.Tensor]) -> int: + for val in x.values(): + return val.size(0) + raise ValueError("Got empty dict.") + + def get_device(self, x: Mapping[str, torch.Tensor]) -> torch.device: + for val in x.values(): + return val.device + raise ValueError("Got empty dict.") + + def encode(self, x: Mapping[str, torch.Tensor]) -> torch.Tensor: + domains = {} + bs = self.get_batch_size(x) + device = self.get_device(x) + for domain in self.gw_interfaces.keys(): + if domain in x: + domains[domain] = x[domain] + else: + domains[domain] = torch.zeros( + bs, self.gw_interfaces[domain].domain_module.latent_dim + ).to(device) + return self.fusion_mechanism( + { + domain: self.gw_interfaces[domain].encode(x[domain]) + for domain in x.keys() + } + ) + + def decode( + self, z: torch.Tensor, domains: Iterable[str] | None = None + ) -> dict[str, torch.Tensor]: + return { + domain: self.gw_interfaces[domain].decode(z) + for domain in domains or self.gw_interfaces.keys() + } + + def translate(self, x: Mapping[str, torch.Tensor], to: str) -> torch.Tensor: + return self.decode(self.encode(x), domains={to})[to] + + def cycle( + self, x: Mapping[str, torch.Tensor], through: str + ) -> dict[str, torch.Tensor]: + return { + domain: self.translate({through: self.translate(x, through)}, domain) + for domain in x.keys() + } diff --git a/shimmer/modules/losses.py b/shimmer/modules/losses.py index b9642962..8a66b010 100644 --- a/shimmer/modules/losses.py +++ b/shimmer/modules/losses.py @@ -2,6 +2,7 @@ from collections.abc import Mapping import torch +import torch.nn.functional as F from shimmer.modules.contrastive_loss import ContrastiveLossType, VarContrastiveLossType from shimmer.modules.dict_buffer import DictBuffer @@ -107,37 +108,36 @@ def _translation_loss( for domains, latents in latent_domains.items(): if len(domains) < 2: continue - for domain_name_source in domains: - z = gw_mod.encode( - gw_mod.on_before_gw_encode_tr( - {domain_name_source: latents[domain_name_source]} - ) + for domain_name_target in domains: + + domain_sources = { + domain: latents[domain] + for domain in domains + if domain != domain_name_target + } + + z = gw_mod.encode(gw_mod.on_before_gw_encode_tr(domain_sources)) + mod = domain_mods[domain_name_target] + + domain_source_names = "/".join(domain_sources.keys()) + loss_name = f"{domain_source_names}_to_{domain_name_target}" + if loss_name in losses.keys(): + raise ValueError(f"{loss_name} is already computed.") + + prediction = gw_mod.decode(z, domains={domain_name_target})[ + domain_name_target + ] + loss_output = mod.compute_tr_loss( + prediction, + latents[domain_name_target], + ) + losses[f"translation_{loss_name}"] = loss_output.loss + metrics.update( + { + f"translation_{loss_name}_{k}": v + for k, v in loss_output.metrics.items() + } ) - - for domain_name_target in domains: - if domain_name_source == domain_name_target: - continue - - mod = domain_mods[domain_name_target] - - loss_name = f"{domain_name_source}_to_{domain_name_target}" - if loss_name in losses.keys(): - raise ValueError(f"{loss_name} is already computed.") - - prediction = gw_mod.decode(z, domains={domain_name_target})[ - domain_name_target - ] - loss_output = mod.compute_tr_loss( - prediction, - latents[domain_name_target], - ) - losses[f"translation_{loss_name}"] = loss_output.loss - metrics.update( - { - f"translation_{loss_name}_{k}": v - for k, v in loss_output.metrics.items() - } - ) losses["translations"] = torch.stack(list(losses.values()), dim=0).mean() losses.update(metrics) return losses @@ -153,7 +153,7 @@ def _contrastive_loss( keys: list[set[str]] = [] for latents in latent_domains.values(): - if len(latents) < 2: + if len(latents) != 2: continue for domain1_name, domain1 in latents.items(): z1 = gw_mod.encode(gw_mod.on_before_gw_encode_cont({domain1_name: domain1})) @@ -354,3 +354,128 @@ def step( ).mean() return LossOutput(loss, metrics) + + +def sample_scaling_factors( + binary_scaling_prob: float, + batch_size: int, + temperature: float, + device: torch.device, +): + """ + Args: + binary_scaling_prob: float + batch_size: int + temperature: float greater than 0 + """ + assert 0 <= binary_scaling_prob <= 1 + + # TODO: make selection deterministic + binary_mask = torch.rand(batch_size) < binary_scaling_prob + + binary_factors = torch.randint(0, 2, (batch_size,)).float() + binary_softmax = torch.stack([binary_factors, 1 - binary_factors], dim=1) + + uniform_samples = torch.rand(batch_size) + uniform_for_softmax = torch.stack([uniform_samples, 1 - uniform_samples], dim=1) + + uniform_softmax = F.softmax(uniform_for_softmax * temperature, dim=1) + + scaling_factors = torch.where( + binary_mask.unsqueeze(-1), binary_softmax, uniform_softmax + ).to(device) + + binary_indices = torch.where(binary_mask)[0] + softmax_indices = torch.where(~binary_mask)[0] + + binary_scaling_factors = scaling_factors[binary_indices] + softmax_scaling_factors = scaling_factors[softmax_indices] + + return { + "binary": ( + binary_scaling_factors[:, 0], + binary_scaling_factors[:, 1], + binary_indices, + ), + "softmax": ( + softmax_scaling_factors[:, 0], + softmax_scaling_factors[:, 1], + softmax_indices, + ), + } + + +class GWFusionLosses(GWLossesBase): + def __init__( + self, + gw_mod: GWModule, + domain_mods: dict[str, DomainModule], + coef_buffers: DictBuffer, + contrastive_fn: ContrastiveLossType, + ): + super().__init__() + self.gw_mod = gw_mod + self.domain_mods = domain_mods + self.loss_coefs = coef_buffers + self.contrastive_fn = contrastive_fn + + def demi_cycle_loss(self, latent_domains: LatentsT) -> dict[str, torch.Tensor]: + return _demi_cycle_loss(self.gw_mod, self.domain_mods, latent_domains) + + def cycle_loss(self, latent_domains: LatentsT) -> dict[str, torch.Tensor]: + return _cycle_loss(self.gw_mod, self.domain_mods, latent_domains) + + def translation_loss(self, latent_domains: LatentsT) -> dict[str, torch.Tensor]: + return _translation_loss(self.gw_mod, self.domain_mods, latent_domains) + + def contrastive_loss(self, latent_domains: LatentsT) -> dict[str, torch.Tensor]: + return _contrastive_loss(self.gw_mod, latent_domains, self.contrastive_fn) + + def broadcast_loss(self, latent_domains: LatentsT) -> dict[str, torch.Tensor]: + losses: dict[str, torch.Tensor] = {} + metrics: dict[str, torch.Tensor] = {} + keys: list[set[str]] = [] + + for latents in latent_domains.values(): + if len(latents) < 2: + continue + for domain1_name, domain1 in latents.items(): + z1 = gw_mod.encode( + gw_mod.on_before_gw_encode_cont({domain1_name: domain1}) + ) + for domain2_name, domain2 in latents.items(): + selected_domains = {domain1_name, domain2_name} + if domain1_name == domain2_name or selected_domains in keys: + continue + + keys.append(selected_domains) + + loss_name = f"contrastive_{domain1_name}_and_{domain2_name}" + z2 = gw_mod.encode( + gw_mod.on_before_gw_encode_cont({domain2_name: domain2}) + ) + loss_output = contrastive_fn(z1, z2) + losses[loss_name] = loss_output.loss + metrics.update( + {f"{loss_name}_{k}": v for k, v in loss_output.metrics.items()} + ) + + losses["contrastives"] = torch.stack(list(losses.values()), dim=0).mean() + losses.update(metrics) + return losses + + def step( + self, + domain_latents: Mapping[frozenset[str], Mapping[str, torch.Tensor]], + ) -> LossOutput: + metrics: dict[str, torch.Tensor] = {} + + metrics.update(self.demi_cycle_loss(domain_latents)) + metrics.update(self.cycle_loss(domain_latents)) + metrics.update(self.translation_loss(domain_latents)) + metrics.update(self.contrastive_loss(domain_latents)) + metrics.update(self.broadcast_loss(domain_latents)) + + loss = metrics["broadcast"] + + return LossOutput(loss, metrics)