Skip to content

Commit

Permalink
docs: add metrics returned by broadcast_loss in the docs (#179)
Browse files Browse the repository at this point in the history
This is useful for the more flexible coefs now.

Note, this also renames the metrics of the broadcast loss (group
case is now formatted the same way as the other groups).
  • Loading branch information
bdvllrs authored Oct 15, 2024
1 parent d5c8f5e commit 519d5f2
Showing 1 changed file with 26 additions and 6 deletions.
32 changes: 26 additions & 6 deletions shimmer/modules/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,27 @@ def broadcast_loss(
"""
Computes broadcast loss including demi-cycle, cycle, and translation losses.
This return multiple metrics:
* `demi_cycles`
* `cycles`
* `translations`
* `fused`
* `from_{start_group}_to_{domain}_loss` where `{start_group}` is of the form
"{domain1,domain2,domainN}" sorted in alphabetical order
(e.g. "from_{t,v}_to_t_loss").
* `from_{start_group}_to_{domain}_{metric}` with
additional metrics provided by the domain_mod's
`compute_broadcast_loss` output
* `from_{start_group}_through_{target_group}_to_{domain}_case_{case_group}_loss`
where `{start_group}`, `{target_group}` and `{case_group}` is of the form
"{domain1,domain2,domainN}" sorted in alphabetical order
(e.g. "from_{t}_through_{v}_to_t_case_{t,v}_loss"). `{start_group}` represents the input
domains, `{target_group}` the target domains used for the cycle and
`{case_group}` all available domains participating to the loss.
* `from_{start_group}_through_{target_group}_to_{domain}_case_{case_group}_{metric}`
additional metrics provided by the domain_mod's `compute_broadcast_loss`
output
Args:
gw_mod (`shimmer.modules.gw_module.GWModuleBase`): The GWModule to use
selection_mod (`shimmer.modules.selection.SelectionBase`): Selection mod to use
Expand All @@ -534,7 +555,7 @@ def broadcast_loss(
Returns:
A dictionary with the total loss and additional metrics.
"""
""" # noqa: E501
losses: dict[str, torch.Tensor] = {}
metrics: dict[str, torch.Tensor] = {}

Expand All @@ -546,19 +567,18 @@ def broadcast_loss(
for group_domains, latents in latent_domains.items():
encoded_latents = gw_mod.encode(latents)
partitions = generate_partitions(len(group_domains))
domain_names = list(latents)
group_name = "-".join(group_domains)
group_name = "{" + ",".join(sorted(group_domains)) + "}"

for partition in partitions:
selected_latents = {
domain: latents[domain]
for domain, present in zip(domain_names, partition, strict=True)
for domain, present in zip(latents, partition, strict=True)
if present
}
selected_encoded_latents = {
domain: encoded_latents[domain] for domain in selected_latents
}
selected_group_label = "{" + ", ".join(sorted(selected_latents)) + "}"
selected_group_label = "{" + ",".join(sorted(selected_latents)) + "}"

selection_scores = selection_mod(selected_latents, selected_encoded_latents)
fused_latents = gw_mod.fuse(selected_encoded_latents, selection_scores)
Expand Down Expand Up @@ -606,7 +626,7 @@ def broadcast_loss(
}

inverse_selected_group_label = (
"{" + ", ".join(sorted(inverse_selected_latents)) + "}"
"{" + ",".join(sorted(inverse_selected_latents)) + "}"
)

re_encoded_latents = gw_mod.encode(inverse_selected_latents)
Expand Down

0 comments on commit 519d5f2

Please sign in to comment.