Skip to content

Commit

Permalink
Can translate from more than 1 domain
Browse files Browse the repository at this point in the history
  • Loading branch information
bdvllrs committed Feb 26, 2024
1 parent 4299fbd commit ecb876f
Showing 1 changed file with 29 additions and 30 deletions.
59 changes: 29 additions & 30 deletions shimmer/modules/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,37 +107,36 @@ def _translation_loss(
for domains, latents in latent_domains.items():
if len(domains) < 2:
continue
for domain_name_source in domains:
z = gw_mod.encode(
gw_mod.on_before_gw_encode_tr(
{domain_name_source: latents[domain_name_source]}
)
for domain_name_target in domains:

domain_sources = {
domain: latents[domain]
for domain in domains
if domain != domain_name_target
}

z = gw_mod.encode(gw_mod.on_before_gw_encode_tr(domain_sources))
mod = domain_mods[domain_name_target]

domain_source_names = "/".join(domain_sources.keys())
loss_name = f"{domain_source_names}_to_{domain_name_target}"
if loss_name in losses.keys():
raise ValueError(f"{loss_name} is already computed.")

prediction = gw_mod.decode(z, domains={domain_name_target})[
domain_name_target
]
loss_output = mod.compute_tr_loss(
prediction,
latents[domain_name_target],
)
losses[f"translation_{loss_name}"] = loss_output.loss
metrics.update(
{
f"translation_{loss_name}_{k}": v
for k, v in loss_output.metrics.items()
}
)

for domain_name_target in domains:
if domain_name_source == domain_name_target:
continue

mod = domain_mods[domain_name_target]

loss_name = f"{domain_name_source}_to_{domain_name_target}"
if loss_name in losses.keys():
raise ValueError(f"{loss_name} is already computed.")

prediction = gw_mod.decode(z, domains={domain_name_target})[
domain_name_target
]
loss_output = mod.compute_tr_loss(
prediction,
latents[domain_name_target],
)
losses[f"translation_{loss_name}"] = loss_output.loss
metrics.update(
{
f"translation_{loss_name}_{k}": v
for k, v in loss_output.metrics.items()
}
)
losses["translations"] = torch.stack(list(losses.values()), dim=0).mean()
losses.update(metrics)
return losses
Expand Down

0 comments on commit ecb876f

Please sign in to comment.