
Commit

Some test
bdvllrs committed Feb 26, 2024
1 parent 4299fbd commit 8132fca
Showing 2 changed files with 215 additions and 31 deletions.
59 changes: 59 additions & 0 deletions shimmer/modules/gw_module.py
@@ -382,3 +382,62 @@ def cycle(
domain: self.translate({through: self.translate(x, through)}, domain)
for domain in x.keys()
}


class GWModuleFusion(GWModuleBase):
def fusion_mechanism(self, x: Mapping[str, torch.Tensor]) -> torch.Tensor:
"""
Merge function used to combine domains.
Args:
x: mapping of domain name to latent representation.
Returns:
The merged representation
"""
return torch.mean(torch.stack(list(x.values())), dim=0)

def get_batch_size(self, x: Mapping[str, torch.Tensor]) -> int:
for val in x.values():
return val.size(0)
raise ValueError("Got empty dict.")

def get_device(self, x: Mapping[str, torch.Tensor]) -> torch.device:
for val in x.values():
return val.device
raise ValueError("Got empty dict.")

def encode(self, x: Mapping[str, torch.Tensor]) -> torch.Tensor:
domains = {}
bs = self.get_batch_size(x)
device = self.get_device(x)
for domain in self.gw_interfaces.keys():
if domain in x:
domains[domain] = x[domain]
else:
domains[domain] = torch.zeros(
bs, self.gw_interfaces[domain].domain_module.latent_dim
).to(device)
        return self.fusion_mechanism(
            {
                # Use the zero-padded dict so that missing domains contribute
                # zero latents to the fused representation.
                domain: self.gw_interfaces[domain].encode(domains[domain])
                for domain in domains
            }
        )

def decode(
self, z: torch.Tensor, domains: Iterable[str] | None = None
) -> dict[str, torch.Tensor]:
return {
domain: self.gw_interfaces[domain].decode(z)
for domain in domains or self.gw_interfaces.keys()
}

def translate(self, x: Mapping[str, torch.Tensor], to: str) -> torch.Tensor:
return self.decode(self.encode(x), domains={to})[to]

def cycle(
self, x: Mapping[str, torch.Tensor], through: str
) -> dict[str, torch.Tensor]:
return {
domain: self.translate({through: self.translate(x, through)}, domain)
for domain in x.keys()
}
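The class above merges whichever domain latents are present: missing domains are zero-padded to the expected latent size and all per-domain representations are averaged. Below is a minimal, standalone sketch of that merge step using plain torch tensors, not the shimmer classes; the domain names, batch size, and latent dimension are illustrative assumptions.

import torch

latent_dim = 12
batch_size = 4
all_domains = ["v", "t"]  # hypothetical domain names

# Only "v" is available in this batch; "t" is padded with zeros.
available = {"v": torch.randn(batch_size, latent_dim)}
padded = {
    name: available.get(name, torch.zeros(batch_size, latent_dim))
    for name in all_domains
}

# fusion_mechanism: mean over the stacked per-domain representations.
fused = torch.mean(torch.stack(list(padded.values())), dim=0)
print(fused.shape)  # torch.Size([4, 12])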
187 changes: 156 additions & 31 deletions shimmer/modules/losses.py
@@ -2,6 +2,7 @@
from collections.abc import Mapping

import torch
import torch.nn.functional as F

from shimmer.modules.contrastive_loss import ContrastiveLossType, VarContrastiveLossType
from shimmer.modules.dict_buffer import DictBuffer
@@ -107,37 +108,36 @@ def _translation_loss(
for domains, latents in latent_domains.items():
if len(domains) < 2:
continue
-        for domain_name_source in domains:
-            z = gw_mod.encode(
-                gw_mod.on_before_gw_encode_tr(
-                    {domain_name_source: latents[domain_name_source]}
-                )
-            )
-            for domain_name_target in domains:
-                if domain_name_source == domain_name_target:
-                    continue
-
-                mod = domain_mods[domain_name_target]
-
-                loss_name = f"{domain_name_source}_to_{domain_name_target}"
-                if loss_name in losses.keys():
-                    raise ValueError(f"{loss_name} is already computed.")
-
-                prediction = gw_mod.decode(z, domains={domain_name_target})[
-                    domain_name_target
-                ]
-                loss_output = mod.compute_tr_loss(
-                    prediction,
-                    latents[domain_name_target],
-                )
-                losses[f"translation_{loss_name}"] = loss_output.loss
-                metrics.update(
-                    {
-                        f"translation_{loss_name}_{k}": v
-                        for k, v in loss_output.metrics.items()
-                    }
-                )
+        for domain_name_target in domains:
+            domain_sources = {
+                domain: latents[domain]
+                for domain in domains
+                if domain != domain_name_target
+            }
+
+            z = gw_mod.encode(gw_mod.on_before_gw_encode_tr(domain_sources))
+            mod = domain_mods[domain_name_target]
+
+            domain_source_names = "/".join(domain_sources.keys())
+            loss_name = f"{domain_source_names}_to_{domain_name_target}"
+            if loss_name in losses.keys():
+                raise ValueError(f"{loss_name} is already computed.")
+
+            prediction = gw_mod.decode(z, domains={domain_name_target})[
+                domain_name_target
+            ]
+            loss_output = mod.compute_tr_loss(
+                prediction,
+                latents[domain_name_target],
+            )
+            losses[f"translation_{loss_name}"] = loss_output.loss
+            metrics.update(
+                {
+                    f"translation_{loss_name}_{k}": v
+                    for k, v in loss_output.metrics.items()
+                }
+            )
losses["translations"] = torch.stack(list(losses.values()), dim=0).mean()
losses.update(metrics)
return losses
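To make the new behaviour concrete: for each target domain, all remaining domains of the group now act jointly as the source, and the loss key joins the source names with "/". A small illustrative sketch with hypothetical domain names (set iteration order, and hence the order of the joined names, is not deterministic):

domains = {"attr", "v", "t"}  # hypothetical domain group

loss_keys = []
for target in domains:
    sources = [d for d in domains if d != target]
    loss_keys.append(f"translation_{'/'.join(sources)}_to_{target}")

print(loss_keys)
# e.g. ['translation_v/t_to_attr', 'translation_attr/t_to_v', 'translation_attr/v_to_t']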
@@ -153,7 +153,7 @@ def _contrastive_loss(
keys: list[set[str]] = []

for latents in latent_domains.values():
-        if len(latents) < 2:
+        if len(latents) != 2:
continue
for domain1_name, domain1 in latents.items():
z1 = gw_mod.encode(gw_mod.on_before_gw_encode_cont({domain1_name: domain1}))
@@ -354,3 +354,128 @@ def step(
).mean()

return LossOutput(loss, metrics)


def sample_scaling_factors(
binary_scaling_prob: float,
batch_size: int,
temperature: float,
device: torch.device,
):
"""
Args:
binary_scaling_prob: float
batch_size: int
temperature: float greater than 0
"""
assert 0 <= binary_scaling_prob <= 1

# TODO: make selection deterministic
binary_mask = torch.rand(batch_size) < binary_scaling_prob

binary_factors = torch.randint(0, 2, (batch_size,)).float()
binary_softmax = torch.stack([binary_factors, 1 - binary_factors], dim=1)

uniform_samples = torch.rand(batch_size)
uniform_for_softmax = torch.stack([uniform_samples, 1 - uniform_samples], dim=1)

uniform_softmax = F.softmax(uniform_for_softmax * temperature, dim=1)

scaling_factors = torch.where(
binary_mask.unsqueeze(-1), binary_softmax, uniform_softmax
).to(device)

binary_indices = torch.where(binary_mask)[0]
softmax_indices = torch.where(~binary_mask)[0]

binary_scaling_factors = scaling_factors[binary_indices]
softmax_scaling_factors = scaling_factors[softmax_indices]

return {
"binary": (
binary_scaling_factors[:, 0],
binary_scaling_factors[:, 1],
binary_indices,
),
"softmax": (
softmax_scaling_factors[:, 0],
softmax_scaling_factors[:, 1],
softmax_indices,
),
}
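A short usage sketch for the helper above. The import path mirrors the file this diff touches, but treat it and the argument values as assumptions; the assertions only check the documented structure of the returned dict.

import torch

from shimmer.modules.losses import sample_scaling_factors  # assumed import path

factors = sample_scaling_factors(
    binary_scaling_prob=0.5,
    batch_size=8,
    temperature=5.0,
    device=torch.device("cpu"),
)

bin_a, bin_b, bin_idx = factors["binary"]      # hard 0/1 weights for these items
soft_a, soft_b, soft_idx = factors["softmax"]  # soft weights in (0, 1) for the rest

# Each pair of factors sums to 1, and the two index tensors partition the batch.
assert torch.allclose(bin_a + bin_b, torch.ones_like(bin_a))
assert torch.allclose(soft_a + soft_b, torch.ones_like(soft_a))
assert bin_idx.numel() + soft_idx.numel() == 8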


class GWFusionLosses(GWLossesBase):
def __init__(
self,
gw_mod: GWModule,
domain_mods: dict[str, DomainModule],
coef_buffers: DictBuffer,
contrastive_fn: ContrastiveLossType,
):
super().__init__()
self.gw_mod = gw_mod
self.domain_mods = domain_mods
self.loss_coefs = coef_buffers
self.contrastive_fn = contrastive_fn

def demi_cycle_loss(self, latent_domains: LatentsT) -> dict[str, torch.Tensor]:
return _demi_cycle_loss(self.gw_mod, self.domain_mods, latent_domains)

def cycle_loss(self, latent_domains: LatentsT) -> dict[str, torch.Tensor]:
return _cycle_loss(self.gw_mod, self.domain_mods, latent_domains)

def translation_loss(self, latent_domains: LatentsT) -> dict[str, torch.Tensor]:
return _translation_loss(self.gw_mod, self.domain_mods, latent_domains)

def contrastive_loss(self, latent_domains: LatentsT) -> dict[str, torch.Tensor]:
return _contrastive_loss(self.gw_mod, latent_domains, self.contrastive_fn)

def broadcast_loss(self, latent_domains: LatentsT) -> dict[str, torch.Tensor]:
losses: dict[str, torch.Tensor] = {}
metrics: dict[str, torch.Tensor] = {}
keys: list[set[str]] = []

for latents in latent_domains.values():
if len(latents) < 2:
continue
for domain1_name, domain1 in latents.items():
                z1 = self.gw_mod.encode(
                    self.gw_mod.on_before_gw_encode_cont({domain1_name: domain1})
                )
                for domain2_name, domain2 in latents.items():
                    selected_domains = {domain1_name, domain2_name}
                    if domain1_name == domain2_name or selected_domains in keys:
                        continue

                    keys.append(selected_domains)

                    # Pairwise term between the GW encodings, logged under "broadcast_*".
                    loss_name = f"broadcast_{domain1_name}_and_{domain2_name}"
                    z2 = self.gw_mod.encode(
                        self.gw_mod.on_before_gw_encode_cont({domain2_name: domain2})
                    )
                    loss_output = self.contrastive_fn(z1, z2)
                    losses[loss_name] = loss_output.loss
                    metrics.update(
                        {f"{loss_name}_{k}": v for k, v in loss_output.metrics.items()}
                    )

        losses["broadcast"] = torch.stack(list(losses.values()), dim=0).mean()
        losses.update(metrics)
        return losses

def step(
self,
domain_latents: Mapping[frozenset[str], Mapping[str, torch.Tensor]],
) -> LossOutput:
metrics: dict[str, torch.Tensor] = {}

metrics.update(self.demi_cycle_loss(domain_latents))
metrics.update(self.cycle_loss(domain_latents))
metrics.update(self.translation_loss(domain_latents))
metrics.update(self.contrastive_loss(domain_latents))
metrics.update(self.broadcast_loss(domain_latents))

loss = metrics["broadcast"]

return LossOutput(loss, metrics)
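Finally, a hedged sketch of the input that step() expects: a mapping from frozensets of domain names to the per-domain latents of that sub-batch. The instantiation of GWFusionLosses is only indicated in comments, since the constructors of GWModule, DomainModule, DictBuffer and the contrastive function are not part of this diff; domain names and sizes are illustrative.

import torch

batch_size, latent_dim = 4, 12  # illustrative sizes
domain_latents = {
    frozenset({"v"}): {"v": torch.randn(batch_size, latent_dim)},
    frozenset({"t"}): {"t": torch.randn(batch_size, latent_dim)},
    frozenset({"v", "t"}): {
        "v": torch.randn(batch_size, latent_dim),
        "t": torch.randn(batch_size, latent_dim),
    },
}

# losses_module = GWFusionLosses(gw_mod, domain_mods, coef_buffers, contrastive_fn)
# output = losses_module.step(domain_latents)
# output.loss is the mean "broadcast" term; every individual term is in output.metrics.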
