From 519d5f284d6b37c05e3317f7f8d0abbc4bd651ea Mon Sep 17 00:00:00 2001 From: bdvllrs Date: Tue, 15 Oct 2024 14:18:06 +0200 Subject: [PATCH] docs: add metrics returned by `broadcast_loss` in the docs (#179) This is useful for the more flexible coefs now. Note, this also renames the metrics of the broadcast loss (group case is now formatted the same way as the other groups). --- shimmer/modules/losses.py | 32 ++++++++++++++++++++++++++------ 1 file changed, 26 insertions(+), 6 deletions(-) diff --git a/shimmer/modules/losses.py b/shimmer/modules/losses.py index ef60f8d..e47c029 100644 --- a/shimmer/modules/losses.py +++ b/shimmer/modules/losses.py @@ -525,6 +525,27 @@ def broadcast_loss( """ Computes broadcast loss including demi-cycle, cycle, and translation losses. + This return multiple metrics: + * `demi_cycles` + * `cycles` + * `translations` + * `fused` + * `from_{start_group}_to_{domain}_loss` where `{start_group}` is of the form + "{domain1,domain2,domainN}" sorted in alphabetical order + (e.g. "from_{t,v}_to_t_loss"). + * `from_{start_group}_to_{domain}_{metric}` with + additional metrics provided by the domain_mod's + `compute_broadcast_loss` output + * `from_{start_group}_through_{target_group}_to_{domain}_case_{case_group}_loss` + where `{start_group}`, `{target_group}` and `{case_group}` is of the form + "{domain1,domain2,domainN}" sorted in alphabetical order + (e.g. "from_{t}_through_{v}_to_t_case_{t,v}_loss"). `{start_group}` represents the input + domains, `{target_group}` the target domains used for the cycle and + `{case_group}` all available domains participating to the loss. + * `from_{start_group}_through_{target_group}_to_{domain}_case_{case_group}_{metric}` + additional metrics provided by the domain_mod's `compute_broadcast_loss` + output + Args: gw_mod (`shimmer.modules.gw_module.GWModuleBase`): The GWModule to use selection_mod (`shimmer.modules.selection.SelectionBase`): Selection mod to use @@ -534,7 +555,7 @@ def broadcast_loss( Returns: A dictionary with the total loss and additional metrics. - """ + """ # noqa: E501 losses: dict[str, torch.Tensor] = {} metrics: dict[str, torch.Tensor] = {} @@ -546,19 +567,18 @@ def broadcast_loss( for group_domains, latents in latent_domains.items(): encoded_latents = gw_mod.encode(latents) partitions = generate_partitions(len(group_domains)) - domain_names = list(latents) - group_name = "-".join(group_domains) + group_name = "{" + ",".join(sorted(group_domains)) + "}" for partition in partitions: selected_latents = { domain: latents[domain] - for domain, present in zip(domain_names, partition, strict=True) + for domain, present in zip(latents, partition, strict=True) if present } selected_encoded_latents = { domain: encoded_latents[domain] for domain in selected_latents } - selected_group_label = "{" + ", ".join(sorted(selected_latents)) + "}" + selected_group_label = "{" + ",".join(sorted(selected_latents)) + "}" selection_scores = selection_mod(selected_latents, selected_encoded_latents) fused_latents = gw_mod.fuse(selected_encoded_latents, selection_scores) @@ -606,7 +626,7 @@ def broadcast_loss( } inverse_selected_group_label = ( - "{" + ", ".join(sorted(inverse_selected_latents)) + "}" + "{" + ",".join(sorted(inverse_selected_latents)) + "}" ) re_encoded_latents = gw_mod.encode(inverse_selected_latents)