Skip to content
This repository has been archived by the owner on Feb 3, 2025. It is now read-only.

Commit

Permalink
Rearrange log spectral distance, add doc
Browse files Browse the repository at this point in the history
  • Loading branch information
OpheliaMiralles committed Jan 13, 2025
1 parent 2150cda commit 088cb2b
Showing 1 changed file with 23 additions and 18 deletions.
41 changes: 23 additions & 18 deletions src/anemoi/training/losses/spatial.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,33 +8,25 @@

LOGGER = logging.getLogger(__name__)


def log_spectral_distance(real_output, fake_output):
epsilon = torch.finfo(torch.float32).eps # Small epsilon to avoid division by zero
power_spectra_real = torch.abs(torch.fft.rfft2(real_output)) ** 2
power_spectra_fake = torch.abs(torch.fft.rfft2(fake_output)) ** 2
ratio = (power_spectra_real + epsilon) / (power_spectra_fake + epsilon)

def log10(x):
return torch.log(x) / torch.log(torch.tensor(10.0, device=x.device, dtype=x.dtype))

result = (10 * log10(ratio)) ** 2
lsd = torch.sqrt(torch.mean(result, dim=(-1, -2, -3))) # Mean over last 3 dimensions
lsd = torch.where(torch.isnan(lsd), torch.zeros_like(lsd), lsd)
return lsd
return result

class LogSpectralDistance(FunctionalWeightedLoss):
"""WeightedLoss which a user can subclass and provide `calculate_difference`.

`calculate_difference` should calculate the difference between the prediction and target.
class LogSpectralDistance(FunctionalWeightedLoss):
"""The log spectral distance is used to compute the difference between spectra of two fields. If is also called log spectral distorsion.
When it is expressed in discrete space with L2 norm, it is defined as:
<math>D_{LS}={\left\{ \frac{1}{N} \sum_{n=1}^N \left[ \log P(n) - \log \hat{P}(n) \right]^2 \right\} }^{1/2} ,</math>.
All scaling and weighting is handled by the parent class.
Example:
--------
```python
class MyLoss(FunctionalWeightedLoss):
def calculate_difference(self, pred, target):
return pred - target
```
"""

def __init__(
Expand All @@ -45,5 +37,18 @@ def __init__(
super().__init__(node_weights, ignore_nans)

def calculate_difference(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""Calculate Difference between prediction and target."""
return log_spectral_distance(pred, target)
return log_spectral_distance(pred, target)

def forward(
self,
pred: torch.Tensor,
target: torch.Tensor,
squash: bool = True,
*,
scalar_indices: tuple[int, ...] | None = None,
without_scalars: list[str] | list[int] | None = None,
) -> torch.Tensor:
result = super().forward(pred, target, squash, scalar_indices, without_scalars)
lsd = torch.sqrt(torch.mean(result, dim=(-1, -2, -3))) # Mean over last 3 dimensions
lsd = torch.where(torch.isnan(lsd), torch.zeros_like(lsd), lsd)
return lsd

0 comments on commit 088cb2b

Please sign in to comment.