diff --git a/shimmer/modules/losses.py b/shimmer/modules/losses.py index b9642962..ce73c3e4 100644 --- a/shimmer/modules/losses.py +++ b/shimmer/modules/losses.py @@ -107,37 +107,36 @@ def _translation_loss( for domains, latents in latent_domains.items(): if len(domains) < 2: continue - for domain_name_source in domains: - z = gw_mod.encode( - gw_mod.on_before_gw_encode_tr( - {domain_name_source: latents[domain_name_source]} - ) + for domain_name_target in domains: + + domain_sources = { + domain: latents[domain] + for domain in domains + if domain != domain_name_target + } + + z = gw_mod.encode(gw_mod.on_before_gw_encode_tr(domain_sources)) + mod = domain_mods[domain_name_target] + + domain_source_names = "/".join(domain_sources.keys()) + loss_name = f"{domain_source_names}_to_{domain_name_target}" + if loss_name in losses.keys(): + raise ValueError(f"{loss_name} is already computed.") + + prediction = gw_mod.decode(z, domains={domain_name_target})[ + domain_name_target + ] + loss_output = mod.compute_tr_loss( + prediction, + latents[domain_name_target], + ) + losses[f"translation_{loss_name}"] = loss_output.loss + metrics.update( + { + f"translation_{loss_name}_{k}": v + for k, v in loss_output.metrics.items() + } ) - - for domain_name_target in domains: - if domain_name_source == domain_name_target: - continue - - mod = domain_mods[domain_name_target] - - loss_name = f"{domain_name_source}_to_{domain_name_target}" - if loss_name in losses.keys(): - raise ValueError(f"{loss_name} is already computed.") - - prediction = gw_mod.decode(z, domains={domain_name_target})[ - domain_name_target - ] - loss_output = mod.compute_tr_loss( - prediction, - latents[domain_name_target], - ) - losses[f"translation_{loss_name}"] = loss_output.loss - metrics.update( - { - f"translation_{loss_name}_{k}": v - for k, v in loss_output.metrics.items() - } - ) losses["translations"] = torch.stack(list(losses.values()), dim=0).mean() losses.update(metrics) return losses