This repository has been archived by the owner on Feb 3, 2025. It is now read-only.

Commit

Merge Interpolator PR
OpheliaMiralles committed Jan 6, 2025
2 parents baa6e1f + f4f7797 commit 6aa4d1b
Showing 6 changed files with 289 additions and 80 deletions.
118 changes: 93 additions & 25 deletions src/anemoi/training/data/datamodule.py
@@ -13,7 +13,8 @@
from functools import cached_property
from typing import TYPE_CHECKING
from typing import Callable

import numpy as np
import os
import pytorch_lightning as pl
from anemoi.datasets.data import open_dataset
from anemoi.models.data_indices.collection import IndexCollection
@@ -49,14 +50,35 @@ def __init__(self, config: DictConfig, graph_data: HeteroData) -> None:
super().__init__()

self.config = config

self.global_rank = int(os.environ.get("SLURM_PROCID",
"0")) # global rank
self.model_comm_group_id = (
self.global_rank // self.config.hardware.num_gpus_per_model
) # id of the model communication group the rank is participating in
self.model_comm_group_rank = (
self.global_rank % self.config.hardware.num_gpus_per_model
) # rank within one model communication group
total_gpus = self.config.hardware.num_gpus_per_node * self.config.hardware.num_nodes
assert (total_gpus) % self.config.hardware.num_gpus_per_model == 0, (
f"GPUs per model {self.config.hardware.num_gpus_per_model} does not divide total GPUs {total_gpus}"
)
self.model_comm_num_groups = (self.config.hardware.num_gpus_per_node *
self.config.hardware.num_nodes //
self.config.hardware.num_gpus_per_model
) # number of model communication groups
LOGGER.debug(
"Rank %d model communication group number %d, with local model communication group rank %d",
self.global_rank,
self.model_comm_group_id,
self.model_comm_group_rank,
)
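# Worked example (illustrative values, not taken from this PR): with 2 nodes,
# 4 GPUs per node and num_gpus_per_model = 2, total_gpus = 8 and there are
# 8 // 2 = 4 model communication groups; global rank 5 then falls into
# group 5 // 2 = 2 with local rank 5 % 2 = 1 within that group.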
self.graph_data = graph_data

# Set the maximum rollout to be expected
self.rollout = (
self.config.training.rollout.max
if self.config.training.rollout.epoch_increment > 0
else self.config.training.rollout.start
)
self.rollout = (self.config.training.rollout.max
if self.config.training.rollout.epoch_increment > 0
else self.config.training.rollout.start)

# Set the training end date if not specified
if self.config.dataloader.training.end is None:
@@ -85,10 +107,55 @@ def supporting_arrays(self) -> dict:
def data_indices(self) -> IndexCollection:
return IndexCollection(self.config, self.ds_train.name_to_index)

@cached_property
def relative_date_indices(self) -> list:
"""Determine a list of relative time indices to load for each batch"""
if hasattr(self.config.training, "explicit_times"):
return sorted(
set(self.config.training.explicit_times.input +
self.config.training.explicit_times.target))

else:  # uses the old default of multistep, timeincrement and rollout
# Use the maximum rollout to be expected
rollout = (
self.config.training.rollout.max
if self.config.training.rollout.epoch_increment > 0 else
self.config.training.rollout.start
)  # NOTE: with gradual rollout, dates for the maximum rollout are always fetched; this was already the case in datamodule.py

multi_step = self.config.training.multistep_input
return [
self.timeincrement * mstep
for mstep in range(multi_step + rollout)
]
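For orientation, a minimal standalone sketch of how the two branches above resolve; the configuration values are hypothetical and chosen purely for illustration:

explicit_input, explicit_target = [0, 6], [1, 2, 3, 4, 5]  # hypothetical explicit_times
print(sorted(set(explicit_input + explicit_target)))  # [0, 1, 2, 3, 4, 5, 6]

multistep, rollout, timeincrement = 2, 3, 1  # hypothetical legacy settings
print([timeincrement * mstep for mstep in range(multistep + rollout)])  # [0, 1, 2, 3, 4]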

def add_model_run_ids(self, data_reader):
"""Determine the model run id of each time index of the data and add to a data_reader object
NOTE/TODO: This is only relevant when training on non-analysis and should be replaced with
a property of the dataset stored in data_reader.
Until then, assumes regular interval of changed model runs
"""
if not hasattr(self.config.dataloader, "model_run_info"):
data_reader.model_run_ids = None
return data_reader

mr_start = np.datetime64(self.config.dataloader.model_run_info.start)
mr_len = self.config.dataloader.model_run_info.length # model run length in number of date indices
assert max(
self.relative_date_indices
) <= mr_len, f"Requested data length {max(self.relative_date_indices)} longer than model run length {mr_len}"

data_reader.model_run_ids = (
data_reader.dates - mr_start) // np.timedelta64(
mr_len * frequency_to_seconds(self.config.data.frequency), 's')
return data_reader
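As a side note, a minimal sketch of the floor division above on synthetic dates; the start date, run length and frequency below are assumptions made only for illustration:

import numpy as np

dates = np.arange(np.datetime64("2020-01-01T00"), np.datetime64("2020-01-03T00"), np.timedelta64(6, "h"))
mr_start = np.datetime64("2020-01-01")
mr_len, freq_seconds = 4, 6 * 3600  # four 6-hourly indices per model run, i.e. 24 h
run_ids = (dates - mr_start) // np.timedelta64(mr_len * freq_seconds, "s")
print(run_ids)  # [0 0 0 0 1 1 1 1]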

@cached_property
def grid_indices(self) -> type[BaseGridIndices]:
reader_group_size = self.config.dataloader.get("read_group_size", self.config.hardware.num_gpus_per_model)
grid_indices = instantiate(self.config.dataloader.grid_indices, reader_group_size=reader_group_size)
reader_group_size = self.config.dataloader.get(
"read_group_size", self.config.hardware.num_gpus_per_model)
grid_indices = instantiate(self.config.dataloader.grid_indices,
reader_group_size=reader_group_size)
grid_indices.setup(self.graph_data)
return grid_indices

@@ -123,13 +190,16 @@ def timeincrement(self) -> int:
@cached_property
def ds_train(self) -> NativeGridDataset:
return self._get_dataset(
open_dataset(OmegaConf.to_container(self.config.dataloader.training, resolve=True)),
open_dataset(
OmegaConf.to_container(self.config.dataloader.training,
resolve=True)),
label="train",
)

@cached_property
def ds_valid(self) -> NativeGridDataset:
r = max(self.rollout, self.config.dataloader.get("validation_rollout", 1))
r = max(self.rollout,
self.config.dataloader.get("validation_rollout", 1))

if not self.config.dataloader.training.end < self.config.dataloader.validation.start:
LOGGER.warning(
@@ -138,24 +208,26 @@ def ds_valid(self) -> NativeGridDataset:
self.config.dataloader.validation.start,
)
return self._get_dataset(
open_dataset(OmegaConf.to_container(self.config.dataloader.validation, resolve=True)),
open_dataset(
OmegaConf.to_container(self.config.dataloader.validation,
resolve=True)),
shuffle=False,
rollout=r,
#rollout=r, #NOTE: see the above
label="validation",
)

@cached_property
def ds_test(self) -> NativeGridDataset:
assert self.config.dataloader.training.end < self.config.dataloader.test.start, (
f"Training end date {self.config.dataloader.training.end} is not before"
f"test start date {self.config.dataloader.test.start}"
)
f"test start date {self.config.dataloader.test.start}")
assert self.config.dataloader.validation.end < self.config.dataloader.test.start, (
f"Validation end date {self.config.dataloader.validation.end} is not before"
f"test start date {self.config.dataloader.test.start}"
)
f"test start date {self.config.dataloader.test.start}")
return self._get_dataset(
open_dataset(OmegaConf.to_container(self.config.dataloader.test, resolve=True)),
open_dataset(
OmegaConf.to_container(self.config.dataloader.test,
resolve=True)),
shuffle=False,
label="test",
)
Expand All @@ -164,12 +236,9 @@ def _get_dataset(
self,
data_reader: Callable,
shuffle: bool = True,
rollout: int = 1,
label: str = "generic",
) -> NativeGridDataset:

r = max(rollout, self.rollout)

data_reader = self.add_model_run_ids(data_reader) # NOTE: Temporary
# Compute effective batch size
effective_bs = (
self.config.dataloader.batch_size["training"]
Expand All @@ -178,16 +247,15 @@ def _get_dataset(
// self.config.hardware.num_gpus_per_model
)

return NativeGridDataset(
data = NativeGridDataset(
data_reader=data_reader,
rollout=r,
multistep=self.config.training.multistep_input,
timeincrement=self.timeincrement,
relative_date_indices=self.relative_date_indices,
shuffle=shuffle,
grid_indices=self.grid_indices,
label=label,
effective_bs=effective_bs,
)
return data

def _get_dataloader(self, ds: NativeGridDataset, stage: str) -> DataLoader:
assert stage in {"training", "validation", "test"}
24 changes: 18 additions & 6 deletions src/anemoi/training/data/dataset.py
@@ -41,6 +41,10 @@ def __init__(
rollout: int = 1,
multistep: int = 1,
timeincrement: int = 1,
relative_date_indices: list = [0, 1, 2],
model_comm_group_rank: int = 0,
model_comm_group_id: int = 0,
model_comm_num_groups: int = 1,
shuffle: bool = True,
label: str = "generic",
effective_bs: int = 1,
@@ -59,6 +63,14 @@ def __init__(
time increment between samples, by default 1
multistep : int, optional
collate (t-1, ... t - multistep) into the input state vector, by default 1
relative_date_indices : list
list of time indices to load from the data relative to the current sample i in __iter__
model_comm_group_rank : int, optional
process rank in the torch.distributed group (important when running on multiple GPUs), by default 0
model_comm_group_id: int, optional
device group ID, by default 0
model_comm_num_groups : int, optional
total number of device groups, by default 1
shuffle : bool, optional
Shuffle batches, by default True
label : str, optional
@@ -94,11 +106,12 @@ def __init__(
self.shuffle = shuffle

# Data dimensions
self.multi_step = multistep
assert self.multi_step > 0, "Multistep value must be greater than zero."
self.ensemble_dim: int = 2
self.ensemble_size = self.data.shape[self.ensemble_dim]

# relative index of dates to extract
self.relative_date_indices = relative_date_indices

@cached_property
def statistics(self) -> dict:
"""Return dataset statistics."""
@@ -136,7 +149,7 @@ def valid_date_indices(self) -> np.ndarray:
dataset length minus rollout minus additional multistep inputs
(if time_increment is 1).
"""
return get_usable_indices(self.data.missing, len(self.data), self.rollout, self.multi_step, self.timeincrement)
return get_usable_indices(self.data.missing, len(self.data), np.array(self.relative_date_indices, dtype=np.int64), self.data.model_run_ids)
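The call above relies on the updated get_usable_indices signature; as a hedged sketch of the kind of check such a helper can perform (not the actual anemoi-training implementation), a start index is kept only if every requested relative date is in range, not missing, and within a single model run:

import numpy as np

def usable_start_indices_sketch(missing: set, series_length: int,
                                relative_indices: np.ndarray,
                                model_run_ids: np.ndarray | None = None) -> np.ndarray:
    # Illustrative only: keep index i if all dates i + offsets are usable samples.
    usable = []
    for i in range(series_length - int(relative_indices.max())):
        needed = i + relative_indices
        if any(int(n) in missing for n in needed):
            continue  # one of the requested dates is missing from the dataset
        if model_run_ids is not None and len(set(model_run_ids[needed])) > 1:
            continue  # requested dates straddle two different model runs
        usable.append(i)
    return np.array(usable, dtype=np.int64)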

def set_comm_group_info(
self,
@@ -281,6 +294,7 @@ def __iter__(self) -> torch.Tensor:
grid_shard_indices = self.grid_indices.get_shard_indices(self.reader_group_rank)
x = self.data[start : end : self.timeincrement, :, :, :]
x = x[..., grid_shard_indices] # select the grid shard
x = self.data[self.relative_date_indices + i]  # NOTE: this requires an update to anemoi-datasets
x = rearrange(x, "dates variables ensemble gridpoints -> dates ensemble gridpoints variables")
self.ensemble_dim = 1

@@ -290,9 +304,7 @@ def __repr__(self) -> str:
return f"""
{super().__repr__()}
Dataset: {self.data}
Rollout: {self.rollout}
Multistep: {self.multi_step}
Timeincrement: {self.timeincrement}
Relative dates: {self.relative_date_indices}
"""


122 changes: 122 additions & 0 deletions src/anemoi/training/train/interpolator.py
@@ -0,0 +1,122 @@
# (C) Copyright 2024 Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.


import logging
import math
import os
from collections import defaultdict
from collections.abc import Generator
from collections.abc import Mapping
from typing import Optional
from typing import Union
from operator import itemgetter

import numpy as np
import pytorch_lightning as pl
import torch
from anemoi.models.data_indices.collection import IndexCollection
from anemoi.models.interface import AnemoiModelInterface
from anemoi.utils.config import DotDict
from hydra.utils import instantiate
from omegaconf import DictConfig
from omegaconf import OmegaConf
from timm.scheduler import CosineLRScheduler
from torch.distributed.distributed_c10d import ProcessGroup
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.utils.checkpoint import checkpoint
from torch_geometric.data import HeteroData

from anemoi.training.losses.utils import grad_scaler
from anemoi.training.losses.weightedloss import BaseWeightedLoss
from anemoi.training.utils.jsonify import map_config_to_primitives
from anemoi.training.utils.masks import Boolean1DMask
from anemoi.training.utils.masks import NoOutputMask

from anemoi.training.train.forecaster import GraphForecaster

LOGGER = logging.getLogger(__name__)

class GraphInterpolator(GraphForecaster):
"""Graph neural network interpolator for PyTorch Lightning."""

def __init__(
self,
*,
config: DictConfig,
graph_data: HeteroData,
statistics: dict,
data_indices: IndexCollection,
metadata: dict,
) -> None:
"""Initialize graph neural network interpolator.
Parameters
----------
config : DictConfig
Job configuration
graph_data : HeteroData
Graph object
statistics : dict
Statistics of the training data
data_indices : IndexCollection
Indices of the training data
metadata : dict
Provenance information
"""
super().__init__(config=config, graph_data=graph_data, statistics=statistics, data_indices=data_indices, metadata=metadata)
self.target_forcing_indices = itemgetter(*config.training.target_forcing.data)(data_indices.data.input.name_to_index)
if isinstance(self.target_forcing_indices, int):
self.target_forcing_indices = [self.target_forcing_indices]
self.boundary_times = config.training.explicit_times.input
self.interp_times = config.training.explicit_times.target
sorted_indices = sorted(set(self.boundary_times + self.interp_times))
self.imap = {data_index: batch_index for batch_index, data_index in enumerate(sorted_indices)}
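To make the index mapping concrete, a small example with hypothetical explicit_times (input = boundary times, target = interpolation times); imap translates a data-time index into its position along the batch time dimension:

boundary_times, interp_times = [0, 6], [2, 4]  # hypothetical config values
sorted_indices = sorted(set(boundary_times + interp_times))  # [0, 2, 4, 6]
imap = {data_index: batch_index for batch_index, data_index in enumerate(sorted_indices)}
print(imap)  # {0: 0, 2: 1, 4: 2, 6: 3}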


def _step(
self,
batch: torch.Tensor,
batch_idx: int,
validation_mode: bool = False,
) -> tuple[torch.Tensor, Mapping[str, torch.Tensor]]:

del batch_idx
loss = torch.zeros(1, dtype=batch.dtype, device=self.device, requires_grad=False)
metrics = {}
y_preds = []

batch = self.model.pre_processors(batch)
x_bound = batch[:, itemgetter(*self.boundary_times)(self.imap)][..., self.data_indices.data.input.full] # (bs, time, ens, latlon, nvar)

tfi = self.target_forcing_indices
target_forcing = torch.empty(batch.shape[0], batch.shape[2], batch.shape[3], len(tfi)+1, device = self.device, dtype = batch.dtype)
for interp_step in self.interp_times:
# get the forcing information for the target interpolation time
target_forcing[..., :len(tfi)] = batch[:, self.imap[interp_step], :, :, tfi]
target_forcing[..., -1] = (interp_step - self.boundary_times[1])/(self.boundary_times[1] - self.boundary_times[0])
# TODO: make the time fraction one of a configurable set of arbitrary custom forcing functions.

y_pred = self(x_bound, target_forcing)
y = batch[:, self.imap[interp_step], :, :, self.data_indices.data.output.full]

loss += checkpoint(self.loss, y_pred, y, use_reentrant=False)

metrics_next = {}
if validation_mode:
metrics_next = self.calculate_val_metrics(y_pred, y, interp_step - 1)  # expects a rollout step but can be repurposed here
metrics.update(metrics_next)
y_preds.extend(y_pred)

loss *= 1.0 / len(self.interp_times)
return loss, metrics, y_preds
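As written above, the last forcing channel is the offset of the target step from the second boundary time, normalised by the gap between the two boundaries; a small illustration with hypothetical explicit_times:

boundary_times, interp_times = [0, 6], [1, 2, 3, 4, 5]  # hypothetical config values
for interp_step in interp_times:
    frac = (interp_step - boundary_times[1]) / (boundary_times[1] - boundary_times[0])
    print(interp_step, round(frac, 3))  # -0.833, -0.667, -0.5, -0.333, -0.167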

def forward(self, x: torch.Tensor, target_forcing: torch.Tensor) -> torch.Tensor:
return self.model(x, target_forcing, self.model_comm_group)