From a285d36844c367e9a370d8180bae6ed3961d4fa4 Mon Sep 17 00:00:00 2001 From: bdvllrs Date: Tue, 13 Aug 2024 10:03:11 +0000 Subject: [PATCH] if compute_loss returns None, it's skipped + Broadcast loss uses tr, dcy, cy and fused losses --- shimmer/modules/domain.py | 32 ++++++++++++++++++++------------ shimmer/modules/losses.py | 25 +++++++++++++++++++++++-- 2 files changed, 43 insertions(+), 14 deletions(-) diff --git a/shimmer/modules/domain.py b/shimmer/modules/domain.py index 8833ed32..c4b67ea6 100644 --- a/shimmer/modules/domain.py +++ b/shimmer/modules/domain.py @@ -93,7 +93,9 @@ def decode(self, z: torch.Tensor) -> Any: """ raise NotImplementedError - def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput: + def compute_loss( + self, pred: torch.Tensor, target: torch.Tensor + ) -> LossOutput | None: """ Generic loss computation the modality. @@ -101,11 +103,13 @@ def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput: pred (`torch.Tensor`): prediction of the model target (`torch.Tensor`): target tensor Results: - `LossOutput`: LossOuput with training loss and additional metrics. + `LossOutput | None`: LossOuput with training loss and additional metrics. """ raise NotImplementedError - def compute_dcy_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput: + def compute_dcy_loss( + self, pred: torch.Tensor, target: torch.Tensor + ) -> LossOutput | None: """ Computes the loss for a demi-cycle. Override if the demi-cycle loss is different that the generic loss. @@ -114,11 +118,13 @@ def compute_dcy_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutp pred (`torch.Tensor`): prediction of the model target (`torch.Tensor`): target tensor Results: - `LossOutput`: LossOuput with training loss and additional metrics. + `LossOutput | None`: LossOuput with training loss and additional metrics. """ return self.compute_loss(pred, target) - def compute_cy_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput: + def compute_cy_loss( + self, pred: torch.Tensor, target: torch.Tensor + ) -> LossOutput | None: """ Computes the loss for a cycle. Override if the cycle loss is different that the generic loss. @@ -127,11 +133,13 @@ def compute_cy_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutpu pred (`torch.Tensor`): prediction of the model target (`torch.Tensor`): target tensor Results: - `LossOutput`: LossOuput with training loss and additional metrics. + `LossOutput | None`: LossOuput with training loss and additional metrics. """ return self.compute_loss(pred, target) - def compute_tr_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutput: + def compute_tr_loss( + self, pred: torch.Tensor, target: torch.Tensor + ) -> LossOutput | None: """ Computes the loss for a translation. Override if the translation loss is different that the generic loss. @@ -140,21 +148,21 @@ def compute_tr_loss(self, pred: torch.Tensor, target: torch.Tensor) -> LossOutpu pred (`torch.Tensor`): prediction of the model target (`torch.Tensor`): target tensor Results: - `LossOutput`: LossOuput with training loss and additional metrics. + `LossOutput | None`: LossOuput with training loss and additional metrics. """ return self.compute_loss(pred, target) - def compute_broadcast_loss( + def compute_fused_loss( self, pred: torch.Tensor, target: torch.Tensor - ) -> LossOutput: + ) -> LossOutput | None: """ - Computes the loss for a broadcast (fusion). Override if the broadcast loss is + Computes the loss for fused (fusion). Override if the fused loss is different that the generic loss. Args: pred (`torch.Tensor`): prediction of the model target (`torch.Tensor`): target tensor Results: - `LossOutput`: LossOuput with training loss and additional metrics. + `LossOutput | None`: LossOuput with training loss and additional metrics. """ return self.compute_loss(pred, target) diff --git a/shimmer/modules/losses.py b/shimmer/modules/losses.py index 08273a93..49b60924 100644 --- a/shimmer/modules/losses.py +++ b/shimmer/modules/losses.py @@ -78,6 +78,8 @@ def demi_cycle_loss( gw_mod.encode_and_fuse(latents, selection_mod), domains={domain_name} )[domain_name] loss_output = domain_mod.compute_dcy_loss(x_recons, latents[domain_name]) + if loss_output is None: + continue losses[f"demi_cycle_{domain_name}"] = loss_output.loss metrics.update( {f"demi_cycle_{domain_name}_{k}": v for k, v in loss_output.metrics.items()} @@ -138,6 +140,9 @@ def cycle_loss( x_recons[domain_name_source], latents_source[domain_name_source], ) + if loss_output is None: + continue + metrics.update( {f"cycle_{loss_name}_{k}": v for k, v in loss_output.metrics.items()} ) @@ -200,6 +205,9 @@ def translation_loss( prediction, latents[domain_name_target], ) + if loss_output is None: + continue + losses[f"translation_{loss_name}"] = loss_output.loss metrics.update( { @@ -565,7 +573,18 @@ def broadcast_loss( if domain not in group_domains: # if we don't have ground truth continue ground_truth = latents[domain] - loss_output = domain_mods[domain].compute_loss(pred, ground_truth) + + if num_active_domains == 1 and domain in selected_latents: + loss_fn = domain_mods[domain].compute_dcy_loss + elif domain not in selected_latents: + loss_fn = domain_mods[domain].compute_tr_loss + else: + loss_fn = domain_mods[domain].compute_fused_loss + + loss_output = loss_fn(pred, ground_truth) + if loss_output is None: + continue + loss_label = f"from_{selected_group_label}_to_{domain}" losses[loss_label + "_loss"] = loss_output.loss metrics.update( @@ -601,9 +620,11 @@ def broadcast_loss( for domain in selected_latents: re_ground_truth = latents[domain] - re_loss_output = domain_mods[domain].compute_loss( + re_loss_output = domain_mods[domain].compute_cy_loss( re_decoded_latents[domain], re_ground_truth ) + if re_loss_output is None: + continue loss_label = ( f"from_{selected_group_label}_" f"through_{inverse_selected_group_label}_to_{domain}_"