Skip to content

Commit

Permalink
add docstring to explain what returning None can offer
Browse files Browse the repository at this point in the history
  • Loading branch information
bdvllrs committed Sep 19, 2024
1 parent 15fae9b commit 897ffd4
Showing 1 changed file with 14 additions and 0 deletions.
14 changes: 14 additions & 0 deletions shimmer/modules/domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ def compute_loss(
target (`torch.Tensor`): target tensor
Results:
`LossOutput | None`: LossOuput with training loss and additional metrics.
If `None` is returned, this loss will be ignored and will not
participate in the total loss.
"""
raise NotImplementedError

Expand All @@ -119,6 +121,9 @@ def compute_dcy_loss(
target (`torch.Tensor`): target tensor
Results:
`LossOutput | None`: LossOuput with training loss and additional metrics.
If `None` is returned, this loss will be ignored and will not
participate in the total loss; it can be used to deactivate
demi-cycle loss for this domain.
"""
return self.compute_loss(pred, target)

Expand All @@ -134,6 +139,9 @@ def compute_cy_loss(
target (`torch.Tensor`): target tensor
Results:
`LossOutput | None`: LossOuput with training loss and additional metrics.
If `None` is returned, this loss will be ignored and will not
participate in the total loss; it can be used to deactivate
cycle loss for this domain.
"""
return self.compute_loss(pred, target)

Expand All @@ -149,6 +157,9 @@ def compute_tr_loss(
target (`torch.Tensor`): target tensor
Results:
`LossOutput | None`: LossOuput with training loss and additional metrics.
If `None` is returned, this loss will be ignored and will not
participate in the total loss; it can be used to deactivate
translation loss for this domain.
"""
return self.compute_loss(pred, target)

Expand All @@ -164,5 +175,8 @@ def compute_fused_loss(
target (`torch.Tensor`): target tensor
Results:
`LossOutput | None`: LossOuput with training loss and additional metrics.
If `None` is returned, this loss will be ignored and will not
participate in the total loss; it can be used to deactivate
fused loss for this domain.
"""
return self.compute_loss(pred, target)

0 comments on commit 897ffd4

Please sign in to comment.