Commit: formatting

OpheliaMiralles committed Jan 31, 2025
1 parent 0fc1b4e commit 8e21b45

Showing 8 changed files with 36 additions and 4,945 deletions.
7 changes: 1 addition & 6 deletions src/anemoi/training/data/datamodule.py
@@ -17,13 +17,8 @@

import numpy as np
import pytorch_lightning as pl
-from hydra.utils import instantiate
-from omegaconf import DictConfig, OmegaConf
-from torch.utils.data import DataLoader

from anemoi.datasets.data import open_dataset
from anemoi.models.data_indices.collection import IndexCollection
from anemoi.training.data.dataset import NativeGridDataset, worker_init_func
from anemoi.utils.dates import frequency_to_seconds
from hydra.utils import instantiate
from omegaconf import DictConfig
@@ -127,7 +122,7 @@ def relative_date_indices(self) -> list:
set(range(multi_step)).union(
[t + multi_step - 1 for t in self.config.training.explicit_times.input],
[t + multi_step - 1 for t in self.config.training.explicit_times.target],
-            )
+            ),
)

# uses the old default of multistep, timeincrement and rollout.
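For orientation, the trailing comma above closes the union that builds the batch's relative date indices. A minimal standalone sketch of that computation, assuming the union is wrapped in sorted(...) as in the interpolator's __init__ later in this commit; the offsets are hypothetical stand-ins for config.training.explicit_times:

    multi_step = 2
    explicit_input = [0]      # hypothetical explicit input offsets
    explicit_target = [1, 6]  # hypothetical explicit target offsets

    relative_date_indices = sorted(
        set(range(multi_step)).union(
            [t + multi_step - 1 for t in explicit_input],
            [t + multi_step - 1 for t in explicit_target],
        ),
    )
    print(relative_date_indices)  # [0, 1, 2, 7]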
18 changes: 8 additions & 10 deletions src/anemoi/training/diagnostics/callbacks/plot.py
@@ -1093,11 +1093,11 @@ def _plot(
for name in self.parameters
}
plot_parameters_target_dict = {
-    pl_module.data_indices.data.output.name_to_index[name]: (
-        name,
-        name not in diagnostics,
-    )
-    for name in self.parameters
+            pl_module.data_indices.data.output.name_to_index[name]: (
+                name,
+                name not in diagnostics,
+            )
+            for name in self.parameters
}

fig = plot_power_spectrum(
@@ -1107,7 +1107,7 @@
data[rollout_step + 1, ...].squeeze(),
output_tensor[rollout_step, ...],
min_delta=self.min_delta,
-                parameters_target=plot_parameters_target_dict
+                parameters_target=plot_parameters_target_dict,
)

self._output_figure(
@@ -1206,8 +1206,6 @@ def _plot(
logger,
fig,
epoch=epoch,
-            tag=
-            f"gnn_pred_val_histo_rstep_{rollout_step:02d}_batch{batch_idx:04d}_rank0",
-            exp_log_tag=
-            f"val_pred_histo_rstep_{rollout_step:02d}_rank{local_rank:01d}",
+            tag=f"gnn_pred_val_histo_rstep_{rollout_step:02d}_batch{batch_idx:04d}_rank0",
+            exp_log_tag=f"val_pred_histo_rstep_{rollout_step:02d}_rank{local_rank:01d}",
)
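The reformatted dict comprehension keys each plotted field by its index in the model output and records whether the field is absent from the diagnostics set. A toy standalone version (names, indices, and the diagnostics set are hypothetical, not anemoi's):

    name_to_index = {"2t": 0, "10u": 1, "tp": 2}  # hypothetical output mapping
    diagnostics = {"tp"}                          # hypothetical diagnostic-only fields
    parameters = ["2t", "tp"]

    plot_parameters_target_dict = {
        name_to_index[name]: (name, name not in diagnostics)
        for name in parameters
    }
    print(plot_parameters_target_dict)  # {0: ('2t', True), 2: ('tp', False)}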
49 changes: 11 additions & 38 deletions src/anemoi/training/diagnostics/plots.py
@@ -320,19 +320,14 @@ def plot_histogram(
n_plots_x, n_plots_y = len(parameters), 1

figsize = (n_plots_y * 4, n_plots_x * 3)
-    fig, ax = plt.subplots(n_plots_x,
-                           n_plots_y,
-                           figsize=figsize,
-                           layout=LAYOUT)
+    fig, ax = plt.subplots(n_plots_x, n_plots_y, figsize=figsize, layout=LAYOUT)
if n_plots_x == 1:
ax = [ax]

-    for plot_idx, (variable_idx,
-                   (variable_name,
-                    output_only)) in enumerate(parameters.items()):
-        variable_batch_index = ([
-            k for k, (i, _) in parameters_target.items() if i == variable_name
-        ] if parameters_target else variable_idx)
+    for plot_idx, (variable_idx, (variable_name, output_only)) in enumerate(parameters.items()):
+        variable_batch_index = (
+            [k for k, (i, _) in parameters_target.items() if i == variable_name] if parameters_target else variable_idx
+        )
yt = y_true[..., variable_batch_index].squeeze()
yp = y_pred[..., variable_idx].squeeze()
# postprocessed outputs so we need to handle possible NaNs
@@ -346,45 +341,23 @@
# enforce the same binning for both histograms
bin_min = min(np.nanmin(yt_xt), np.nanmin(yp_xt))
bin_max = max(np.nanmax(yt_xt), np.nanmax(yp_xt))
-            hist_yt, bins_yt = np.histogram(yt_xt[~np.isnan(yt_xt)],
-                                            bins=100,
-                                            density=True,
-                                            range=[bin_min, bin_max])
-            hist_yp, bins_yp = np.histogram(yp_xt[~np.isnan(yp_xt)],
-                                            bins=100,
-                                            density=True,
-                                            range=[bin_min, bin_max])
+            hist_yt, bins_yt = np.histogram(yt_xt[~np.isnan(yt_xt)], bins=100, density=True, range=[bin_min, bin_max])
+            hist_yp, bins_yp = np.histogram(yp_xt[~np.isnan(yp_xt)], bins=100, density=True, range=[bin_min, bin_max])
else:
# enforce the same binning for both histograms
bin_min = min(np.nanmin(yt), np.nanmin(yp))
bin_max = max(np.nanmax(yt), np.nanmax(yp))
-            hist_yt, bins_yt = np.histogram(yt[~np.isnan(yt)],
-                                            bins=100,
-                                            density=True,
-                                            range=[bin_min, bin_max])
-            hist_yp, bins_yp = np.histogram(yp[~np.isnan(yp)],
-                                            bins=100,
-                                            density=True,
-                                            range=[bin_min, bin_max])
+            hist_yt, bins_yt = np.histogram(yt[~np.isnan(yt)], bins=100, density=True, range=[bin_min, bin_max])
+            hist_yp, bins_yp = np.histogram(yp[~np.isnan(yp)], bins=100, density=True, range=[bin_min, bin_max])

# Visualization trick for tp
if variable_name in precip_and_related_fields:
# in-place multiplication does not work here because variables are different numpy types
hist_yt = hist_yt * bins_yt[:-1]
hist_yp = hist_yp * bins_yp[:-1]
# Plot the modified histogram
-        ax[plot_idx].bar(bins_yt[:-1],
-                         hist_yt,
-                         width=np.diff(bins_yt),
-                         color="blue",
-                         alpha=0.7,
-                         label="Truth (data)")
-        ax[plot_idx].bar(bins_yp[:-1],
-                         hist_yp,
-                         width=np.diff(bins_yp),
-                         color="red",
-                         alpha=0.7,
-                         label="Predicted")
+        ax[plot_idx].bar(bins_yt[:-1], hist_yt, width=np.diff(bins_yt), color="blue", alpha=0.7, label="Truth (data)")
+        ax[plot_idx].bar(bins_yp[:-1], hist_yp, width=np.diff(bins_yp), color="red", alpha=0.7, label="Predicted")

ax[plot_idx].set_title(variable_name)
ax[plot_idx].set_xlabel(variable_name)
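The joined lines preserve plot_histogram's two tricks: truth and prediction share one set of bin edges so their bars are directly comparable, and precipitation-like fields are reweighted by the bin left edges to tame the dominant near-zero bin. A self-contained sketch under assumed toy data:

    import numpy as np

    yt = np.random.gamma(0.5, 1.0, 1000)  # stand-in for truth values of one field
    yp = np.random.gamma(0.6, 1.0, 1000)  # stand-in for predicted values

    # enforce the same binning for both histograms
    bin_min = min(np.nanmin(yt), np.nanmin(yp))
    bin_max = max(np.nanmax(yt), np.nanmax(yp))
    hist_yt, bins_yt = np.histogram(yt[~np.isnan(yt)], bins=100, density=True, range=[bin_min, bin_max])
    hist_yp, bins_yp = np.histogram(yp[~np.isnan(yp)], bins=100, density=True, range=[bin_min, bin_max])

    # visualization trick for precip-like fields: weight each bin by its left edge;
    # out-of-place multiply because the two arrays can have different numpy dtypes
    hist_yt = hist_yt * bins_yt[:-1]
    hist_yp = hist_yp * bins_yp[:-1]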
9 changes: 4 additions & 5 deletions src/anemoi/training/losses/combined.py
@@ -27,7 +27,6 @@ def __init__(
self,
losses: Sequence[torch.nn.Module],
loss_weights: tuple[int, ...],
-        **kwargs,
):
"""Combined loss function.
@@ -106,10 +105,10 @@ def forward(
loss = None
for i, loss_fn in enumerate(self.losses):
sub_loss = self.loss_weights[i] * loss_fn(
-            pred,
-            target,
-            **kwargs,
-        )
+                pred,
+                target,
+                **kwargs,
+            )
if loss is not None:
loss += sub_loss.expand_as(loss)
else:
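The loop above forms a weighted sum of sub-losses, seeding the accumulator with the first term. A minimal sketch of the same pattern with placeholder losses and weights (not anemoi's CombinedLoss API):

    import torch

    losses = [torch.nn.MSELoss(), torch.nn.L1Loss()]  # placeholder sub-losses
    loss_weights = (1.0, 0.5)                         # placeholder weights

    def combined_forward(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        loss = None
        for weight, loss_fn in zip(loss_weights, losses):
            sub_loss = weight * loss_fn(pred, target)
            # expand_as keeps accumulation consistent if a sub-loss broadcasts
            loss = sub_loss if loss is None else loss + sub_loss.expand_as(loss)
        return loss

    print(combined_forward(torch.randn(4, 8), torch.randn(4, 8)))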
17 changes: 7 additions & 10 deletions src/anemoi/training/losses/timeweights.py
@@ -11,8 +11,6 @@
import logging

import torch
-from anemoi.graphs.nodes.attributes import AreaWeights
-from torch_geometric.data import HeteroData

LOGGER = logging.getLogger(__name__)

@@ -63,15 +61,15 @@ def forward_weights(self, relative_date_indices: list[int]) -> torch.Tensor:
"""
if self.method == "exponential":
return torch.exp(-self.decay_factor * torch.tensor(relative_date_indices))
-        elif self.method == "linear":
+        if self.method == "linear":
return (
1
- (1 - self.decay_factor)
* torch.tensor(relative_date_indices)
/ torch.tensor(relative_date_indices).max()
)
-        else:
-            raise ValueError(f"Method {self.method} not supported")
+        msg = f"Method {self.method} not supported"
+        raise NotImplementedError(msg)

def backward_weights(self, relative_date_indices: list[int]) -> torch.Tensor:
"""Returns weight of type self.node_attribute for nodes self.target.
@@ -91,14 +89,14 @@ def backward_weights(self, relative_date_indices: list[int]) -> torch.Tensor:
"""
if self.method == "exponential":
return torch.exp(self.decay_factor * torch.tensor(relative_date_indices))
-        elif self.method == "linear":
+        if self.method == "linear":
return (
(1 - self.decay_factor)
* torch.tensor(relative_date_indices)
/ torch.tensor(relative_date_indices).max()
)
-        else:
-            raise ValueError(f"Method {self.method} not supported")
+        msg = f"Method {self.method} not supported"
+        raise NotImplementedError(msg)

def weights(self, relative_date_indices: list[int]) -> torch.Tensor:
"""Returns weight of type self.node_attribute for nodes self.target.
@@ -118,5 +116,4 @@ def weights(self, relative_date_indices: list[int]) -> torch.Tensor:
"""
if self.inverse:
return self.backward_weights(relative_date_indices)
else:
return self.forward_weights(relative_date_indices)
return self.forward_weights(relative_date_indices)
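The two supported schemes weight a lead time t as exp(-decay_factor * t) (exponential) or as a linear ramp from 1 at t = 0 down to decay_factor at the last step; backward_weights inverts the emphasis. A worked example of the forward direction with an assumed decay_factor of 0.5:

    import torch

    decay_factor = 0.5  # assumed value for illustration
    idx = torch.tensor([0.0, 1.0, 2.0, 3.0])

    # "exponential": later steps count exponentially less
    print(torch.exp(-decay_factor * idx))            # tensor([1.0000, 0.6065, 0.3679, 0.2231])

    # "linear": 1 at t=0 down to decay_factor at t=max
    print(1 - (1 - decay_factor) * idx / idx.max())  # tensor([1.0000, 0.8333, 0.6667, 0.5000])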
2 changes: 1 addition & 1 deletion src/anemoi/training/train/forecaster.py
@@ -251,7 +251,7 @@ def full_name(type_: type) -> str:
if node_weights.dtype == torch.bool:
node_weights = node_weights / node_weights.sum()
kwargs["node_weights"] = node_weights

if config.get("time_weights", None) is not None:
time_weights = instantiate(config.time_weights)
time_weights = time_weights.weights(self.relative_date_indices)
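The context above normalizes a boolean node mask into uniform weights that sum to one before the optional time_weights block. A small sketch of that normalization with a toy mask:

    import torch

    node_weights = torch.tensor([True, False, True, True])  # toy node mask
    if node_weights.dtype == torch.bool:
        # True entries become 1/n_selected, so the weights sum to 1
        node_weights = node_weights / node_weights.sum()
    print(node_weights)  # tensor([0.3333, 0.0000, 0.3333, 0.3333])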
8 changes: 4 additions & 4 deletions src/anemoi/training/train/interpolator.py
@@ -9,12 +9,12 @@


import logging
-from einops import rearrange
from collections.abc import Mapping
from operator import itemgetter

import torch
from anemoi.models.data_indices.collection import IndexCollection
+from einops import rearrange
from omegaconf import DictConfig
from torch.utils.checkpoint import checkpoint
from torch_geometric.data import HeteroData
@@ -77,7 +77,7 @@ def __init__(
set(range(self.multi_step)).union(
self.boundary_times,
self.interp_times,
-            )
+            ),
)
self.imap = {data_index: batch_index for batch_index, data_index in enumerate(sorted_indices)}

@@ -111,8 +111,8 @@ def _step(
time_weights = self.loss.losses[0].loss.time_weights
for interp_step in self.interp_times:
# update time weights in loss function for this specific case
-            for l in self.loss.losses:
-                l.loss.time_weights = time_weights[self.imap[interp_step]]
+            for specific_loss in self.loss.losses:
+                specific_loss.loss.time_weights = time_weights[self.imap[interp_step]]
# get the forcing information for the target interpolation time:
target_forcing[..., : len(kfv)] = batch[:, self.imap[interp_step], :, :, kfv]
target_forcing[..., -1] = (interp_step - future) / (future - present)
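The renamed loop indexes time_weights through self.imap, which maps a data-time index to its position along the batch's time dimension. A toy reconstruction of how __init__ builds that map (boundary and interpolation offsets are hypothetical):

    multi_step = 2
    boundary_times = [0, 6]  # hypothetical boundary offsets
    interp_times = [2, 4]    # hypothetical interpolation targets

    sorted_indices = sorted(
        set(range(multi_step)).union(boundary_times, interp_times),
    )
    imap = {data_index: batch_index for batch_index, data_index in enumerate(sorted_indices)}
    print(imap)  # {0: 0, 1: 1, 2: 2, 4: 3, 6: 4}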