Skip to content

Commit

Permalink
if compute_loss returns None, it's skipped + Broadcast loss uses tr, dcy, cy and fused losses
Browse files Browse the repository at this point in the history
  • Loading branch information
bdvllrs committed Aug 13, 2024
1 parent 2545155 commit 991e53f
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 14 deletions.
32 changes: 20 additions & 12 deletions shimmer/modules/domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,19 +93,23 @@ def decode(self, z: torch.Tensor) -> Any:
"""
raise NotImplementedError

def compute_loss(
    self, pred: torch.Tensor, target: torch.Tensor
) -> LossOutput | None:
    """
    Generic loss computation for the modality.

    Args:
        pred (`torch.Tensor`): prediction of the model
        target (`torch.Tensor`): target tensor

    Returns:
        `LossOutput | None`: LossOutput with the training loss and additional
            metrics. If `None` is returned, this loss is skipped by the
            caller during training.
    """
    # Abstract: concrete modality modules must override this.
    raise NotImplementedError

def compute_dcy_loss(
    self, pred: torch.Tensor, target: torch.Tensor
) -> LossOutput | None:
    """
    Computes the loss for a demi-cycle. Override if the demi-cycle loss is
    different from the generic loss.

    Args:
        pred (`torch.Tensor`): prediction of the model
        target (`torch.Tensor`): target tensor

    Returns:
        `LossOutput | None`: LossOutput with the training loss and additional
            metrics. If `None` is returned, this loss is skipped by the
            caller during training.
    """
    # Defaults to the generic modality loss.
    return self.compute_loss(pred, target)

def compute_cy_loss(
    self, pred: torch.Tensor, target: torch.Tensor
) -> LossOutput | None:
    """
    Computes the loss for a cycle. Override if the cycle loss is
    different from the generic loss.

    Args:
        pred (`torch.Tensor`): prediction of the model
        target (`torch.Tensor`): target tensor

    Returns:
        `LossOutput | None`: LossOutput with the training loss and additional
            metrics. If `None` is returned, this loss is skipped by the
            caller during training.
    """
    # Defaults to the generic modality loss.
    return self.compute_loss(pred, target)

def compute_tr_loss(
    self, pred: torch.Tensor, target: torch.Tensor
) -> LossOutput | None:
    """
    Computes the loss for a translation. Override if the translation loss is
    different from the generic loss.

    Args:
        pred (`torch.Tensor`): prediction of the model
        target (`torch.Tensor`): target tensor

    Returns:
        `LossOutput | None`: LossOutput with the training loss and additional
            metrics. If `None` is returned, this loss is skipped by the
            caller during training.
    """
    # Defaults to the generic modality loss.
    return self.compute_loss(pred, target)

def compute_fused_loss(
    self, pred: torch.Tensor, target: torch.Tensor
) -> LossOutput | None:
    """
    Computes the loss for a fused (fusion) prediction. Override if the fused
    loss is different from the generic loss.

    Args:
        pred (`torch.Tensor`): prediction of the model
        target (`torch.Tensor`): target tensor

    Returns:
        `LossOutput | None`: LossOutput with the training loss and additional
            metrics. If `None` is returned, this loss is skipped by the
            caller during training.
    """
    # Defaults to the generic modality loss.
    return self.compute_loss(pred, target)
25 changes: 23 additions & 2 deletions shimmer/modules/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ def demi_cycle_loss(
gw_mod.encode_and_fuse(latents, selection_mod), domains={domain_name}
)[domain_name]
loss_output = domain_mod.compute_dcy_loss(x_recons, latents[domain_name])
if loss_output is None:
continue
losses[f"demi_cycle_{domain_name}"] = loss_output.loss
metrics.update(
{f"demi_cycle_{domain_name}_{k}": v for k, v in loss_output.metrics.items()}
Expand Down Expand Up @@ -138,6 +140,9 @@ def cycle_loss(
x_recons[domain_name_source],
latents_source[domain_name_source],
)
if loss_output is None:
continue

metrics.update(
{f"cycle_{loss_name}_{k}": v for k, v in loss_output.metrics.items()}
)
Expand Down Expand Up @@ -200,6 +205,9 @@ def translation_loss(
prediction,
latents[domain_name_target],
)
if loss_output is None:
continue

losses[f"translation_{loss_name}"] = loss_output.loss
metrics.update(
{
Expand Down Expand Up @@ -565,7 +573,18 @@ def broadcast_loss(
if domain not in group_domains: # if we don't have ground truth
continue
ground_truth = latents[domain]
loss_output = domain_mods[domain].compute_loss(pred, ground_truth)

if num_active_domains == 1 and domain in selected_latents:
loss_fn = domain_mods[domain].compute_dcy_loss
elif domain not in selected_latents:
loss_fn = domain_mods[domain].compute_tr_loss
else:
loss_fn = domain_mods[domain].compute_fused_loss

loss_output = loss_fn(pred, ground_truth)
if loss_output is None:
continue

loss_label = f"from_{selected_group_label}_to_{domain}"
losses[loss_label + "_loss"] = loss_output.loss
metrics.update(
Expand Down Expand Up @@ -601,9 +620,11 @@ def broadcast_loss(

for domain in selected_latents:
re_ground_truth = latents[domain]
re_loss_output = domain_mods[domain].compute_loss(
re_loss_output = domain_mods[domain].compute_cy_loss(
re_decoded_latents[domain], re_ground_truth
)
if re_loss_output is None:
continue
loss_label = (
f"from_{selected_group_label}_"
f"through_{inverse_selected_group_label}_to_{domain}_"
Expand Down

0 comments on commit 991e53f

Please sign in to comment.