diff --git a/src/anemoi/training/data/datamodule.py b/src/anemoi/training/data/datamodule.py
index e0502acd..d3ae10bc 100644
--- a/src/anemoi/training/data/datamodule.py
+++ b/src/anemoi/training/data/datamodule.py
@@ -13,7 +13,9 @@
 from functools import cached_property
 from typing import TYPE_CHECKING
 from typing import Callable
+import os
 
+import numpy as np
 import pytorch_lightning as pl
 from anemoi.datasets.data import open_dataset
 from anemoi.models.data_indices.collection import IndexCollection
@@ -49,14 +51,34 @@ def __init__(self, config: DictConfig, graph_data: HeteroData) -> None:
         super().__init__()
 
         self.config = config
+
+        # global rank of this process (SLURM_PROCID is unset outside SLURM, hence the "0" fallback)
+        self.global_rank = int(os.environ.get("SLURM_PROCID", "0"))
+        self.model_comm_group_id = (
+            self.global_rank // self.config.hardware.num_gpus_per_model
+        )  # id of the model communication group the rank is participating in
+        self.model_comm_group_rank = (
+            self.global_rank % self.config.hardware.num_gpus_per_model
+        )  # rank within one model communication group
+        total_gpus = self.config.hardware.num_gpus_per_node * self.config.hardware.num_nodes
+        assert total_gpus % self.config.hardware.num_gpus_per_model == 0, (
+            f"GPUs per model {self.config.hardware.num_gpus_per_model} does not divide total GPUs {total_gpus}"
+        )
+        self.model_comm_num_groups = (
+            total_gpus // self.config.hardware.num_gpus_per_model
+        )  # number of model communication groups
+        LOGGER.debug(
+            "Rank %d model communication group number %d, with local model communication group rank %d",
+            self.global_rank,
+            self.model_comm_group_id,
+            self.model_comm_group_rank,
+        )
         self.graph_data = graph_data
 
         # Set the maximum rollout to be expected
-        self.rollout = (
-            self.config.training.rollout.max
-            if self.config.training.rollout.epoch_increment > 0
-            else self.config.training.rollout.start
-        )
+        self.rollout = (self.config.training.rollout.max
+                        if self.config.training.rollout.epoch_increment > 0
+                        else self.config.training.rollout.start)
 
         # Set the training end date if not specified
         if self.config.dataloader.training.end is None:
@@ -85,10 +107,55 @@ def supporting_arrays(self) -> dict:
     def data_indices(self) -> IndexCollection:
         return IndexCollection(self.config, self.ds_train.name_to_index)
 
+    @cached_property
+    def relative_date_indices(self) -> list:
+        """Determine a list of relative time indices to load for each batch."""
+        if hasattr(self.config.training, "explicit_times"):
+            return sorted(
+                set(self.config.training.explicit_times.input +
+                    self.config.training.explicit_times.target))
+
+        else:  # uses the old default of multistep, timeincrement and rollout.
+            # Use the maximum rollout to be expected
+            rollout = (
+                self.config.training.rollout.max
+                if self.config.training.rollout.epoch_increment > 0 else
+                self.config.training.rollout.start
+            )  # NOTE: for gradual rollout, the dates for the max rollout are always fetched (as before)
+
+            multi_step = self.config.training.multistep_input
+            return [
+                self.timeincrement * mstep
+                for mstep in range(multi_step + rollout)
+            ]
+
+    def add_model_run_ids(self, data_reader: Callable) -> Callable:
+        """Determine the model run id of each time index of the data and add to a data_reader object
+        NOTE/TODO: This is only relevant when training on non-analysis and should be replaced with
+        a property of the dataset stored in data_reader.
+        Until then, assumes regular interval of changed model runs
+        """
+        if not hasattr(self.config.dataloader, "model_run_info"):
+            data_reader.model_run_ids = None
+            return data_reader
+
+        mr_start = np.datetime64(self.config.dataloader.model_run_info.start)
+        mr_len = self.config.dataloader.model_run_info.length  # model run length in number of date indices
+        assert max(
+            self.relative_date_indices
+        ) < mr_len, f"Requested data length {max(self.relative_date_indices) + 1} longer than model run length {mr_len}"
+
+        data_reader.model_run_ids = (
+            data_reader.dates - mr_start) // np.timedelta64(
+                mr_len * frequency_to_seconds(self.config.data.frequency), 's')
+        return data_reader
+
     @cached_property
     def grid_indices(self) -> type[BaseGridIndices]:
-        reader_group_size = self.config.dataloader.get("read_group_size", self.config.hardware.num_gpus_per_model)
-        grid_indices = instantiate(self.config.dataloader.grid_indices, reader_group_size=reader_group_size)
+        reader_group_size = self.config.dataloader.get(
+            "read_group_size", self.config.hardware.num_gpus_per_model)
+        grid_indices = instantiate(self.config.dataloader.grid_indices,
+                                   reader_group_size=reader_group_size)
         grid_indices.setup(self.graph_data)
         return grid_indices
 
@@ -123,13 +190,16 @@ def timeincrement(self) -> int:
     @cached_property
     def ds_train(self) -> NativeGridDataset:
         return self._get_dataset(
-            open_dataset(OmegaConf.to_container(self.config.dataloader.training, resolve=True)),
+            open_dataset(
+                OmegaConf.to_container(self.config.dataloader.training,
+                                       resolve=True)),
             label="train",
         )
 
     @cached_property
     def ds_valid(self) -> NativeGridDataset:
-        r = max(self.rollout, self.config.dataloader.get("validation_rollout", 1))
+        r = max(self.rollout,
+                self.config.dataloader.get("validation_rollout", 1))
 
         if not self.config.dataloader.training.end < self.config.dataloader.validation.start:
             LOGGER.warning(
@@ -138,9 +208,11 @@ def ds_valid(self) -> NativeGridDataset:
                 self.config.dataloader.validation.start,
             )
         return self._get_dataset(
-            open_dataset(OmegaConf.to_container(self.config.dataloader.validation, resolve=True)),
+            open_dataset(
+                OmegaConf.to_container(self.config.dataloader.validation,
+                                       resolve=True)),
             shuffle=False,
-            rollout=r,
+            # rollout=r,  # NOTE: validation rollout is not forwarded; dates come from relative_date_indices
             label="validation",
         )
 
@@ -148,14 +220,14 @@ def ds_valid(self) -> NativeGridDataset:
     def ds_test(self) -> NativeGridDataset:
         assert self.config.dataloader.training.end < self.config.dataloader.test.start, (
             f"Training end date {self.config.dataloader.training.end} is not before"
-            f"test start date {self.config.dataloader.test.start}"
-        )
+            f" test start date {self.config.dataloader.test.start}")
         assert self.config.dataloader.validation.end < self.config.dataloader.test.start, (
             f"Validation end date {self.config.dataloader.validation.end} is not before"
-            f"test start date {self.config.dataloader.test.start}"
-        )
+            f" test start date {self.config.dataloader.test.start}")
         return self._get_dataset(
-            open_dataset(OmegaConf.to_container(self.config.dataloader.test, resolve=True)),
+            open_dataset(
+                OmegaConf.to_container(self.config.dataloader.test,
+                                       resolve=True)),
             shuffle=False,
             label="test",
         )
@@ -164,12 +236,9 @@ def _get_dataset(
         self,
         data_reader: Callable,
         shuffle: bool = True,
-        rollout: int = 1,
         label: str = "generic",
     ) -> NativeGridDataset:
-
-        r = max(rollout, self.rollout)
-
+        data_reader = self.add_model_run_ids(data_reader)  # NOTE: Temporary
         # Compute effective batch size
         effective_bs = (
             self.config.dataloader.batch_size["training"]
@@ -178,16 +247,14 @@ def _get_dataset(
             // self.config.hardware.num_gpus_per_model
         )
 
         return NativeGridDataset(
             data_reader=data_reader,
-            rollout=r,
-            multistep=self.config.training.multistep_input,
-            timeincrement=self.timeincrement,
+            relative_date_indices=self.relative_date_indices,
             shuffle=shuffle,
             grid_indices=self.grid_indices,
             label=label,
             effective_bs=effective_bs,
         )
 
     def _get_dataloader(self, ds: NativeGridDataset, stage: str) -> DataLoader:
         assert stage in {"training", "validation", "test"}
diff --git a/src/anemoi/training/data/dataset.py b/src/anemoi/training/data/dataset.py
index 431f0227..69b7c23f 100644
--- a/src/anemoi/training/data/dataset.py
+++ b/src/anemoi/training/data/dataset.py
@@ -41,6 +41,10 @@ def __init__(
         rollout: int = 1,
         multistep: int = 1,
         timeincrement: int = 1,
+        relative_date_indices: list | None = None,
+        model_comm_group_rank: int = 0,
+        model_comm_group_id: int = 0,
+        model_comm_num_groups: int = 1,
         shuffle: bool = True,
         label: str = "generic",
         effective_bs: int = 1,
@@ -59,6 +63,14 @@ def __init__(
             time increment between samples, by default 1
         multistep : int, optional
             collate (t-1, ... t - multistep) into the input state vector, by default 1
+        relative_date_indices : list, optional
+            time indices to load relative to the current sample i in __iter__, by default [0, 1, 2]
+        model_comm_group_rank : int, optional
+            process rank in the torch.distributed group (important when running on multiple GPUs), by default 0
+        model_comm_group_id: int, optional
+            device group ID, default 0
+        model_comm_num_groups : int, optional
+            total number of device groups, by default 1
         shuffle : bool, optional
             Shuffle batches, by default True
         label : str, optional
@@ -99,6 +111,9 @@ def __init__(
         self.ensemble_dim: int = 2
         self.ensemble_size = self.data.shape[self.ensemble_dim]
 
+        # relative index of dates to extract (a fresh list per instance, never a shared default)
+        self.relative_date_indices = [0, 1, 2] if relative_date_indices is None else list(relative_date_indices)
+
     @cached_property
     def statistics(self) -> dict:
         """Return dataset statistics."""
@@ -136,7 +151,7 @@ def valid_date_indices(self) -> np.ndarray:
         dataset length minus rollout minus additional multistep inputs
         (if time_increment is 1).
         """
-        return get_usable_indices(self.data.missing, len(self.data), self.rollout, self.multi_step, self.timeincrement)
+        return get_usable_indices(self.data.missing, len(self.data), np.array(self.relative_date_indices, dtype=np.int64), getattr(self.data, "model_run_ids", None))
 
     def set_comm_group_info(
         self,
@@ -281,6 +296,7 @@ def __iter__(self) -> torch.Tensor:
             grid_shard_indices = self.grid_indices.get_shard_indices(self.reader_group_rank)
-            x = self.data[start : end : self.timeincrement, :, :, :]
+            # NOTE: indexing with an array of dates requires an update to anemoi datasets
+            x = self.data[np.asarray(self.relative_date_indices) + i]
             x = x[..., grid_shard_indices]  # select the grid shard
             x = rearrange(x, "dates variables ensemble gridpoints -> dates ensemble gridpoints variables")
             self.ensemble_dim = 1
 
@@ -290,9 +306,7 @@ def __repr__(self) -> str:
         return f"""
             {super().__repr__()}
             Dataset: {self.data}
-            Rollout: {self.rollout}
-            Multistep: {self.multi_step}
-            Timeincrement: {self.timeincrement}
+            Relative dates: {self.relative_date_indices}
         """
 
 
diff --git a/src/anemoi/training/train/interpolator.py b/src/anemoi/training/train/interpolator.py
new file mode 100644
index 00000000..4c330d64
--- /dev/null
+++ b/src/anemoi/training/train/interpolator.py
@@ -0,0 +1,118 @@
+# (C) Copyright 2024 Anemoi contributors.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+
+
+import logging
+from collections.abc import Mapping
+
+import torch
+from anemoi.models.data_indices.collection import IndexCollection
+from omegaconf import DictConfig
+from torch.utils.checkpoint import checkpoint
+from torch_geometric.data import HeteroData
+
+from anemoi.training.train.forecaster import GraphForecaster
+
+LOGGER = logging.getLogger(__name__)
+
+
+class GraphInterpolator(GraphForecaster):
+    """Graph neural network interpolator for PyTorch Lightning."""
+
+    def __init__(
+        self,
+        *,
+        config: DictConfig,
+        graph_data: HeteroData,
+        statistics: dict,
+        data_indices: IndexCollection,
+        metadata: dict,
+        supporting_arrays: dict,
+    ) -> None:
+        """Initialize graph neural network interpolator.
+
+        Parameters
+        ----------
+        config : DictConfig
+            Job configuration
+        graph_data : HeteroData
+            Graph object
+        statistics : dict
+            Statistics of the training data
+        data_indices : IndexCollection
+            Indices of the training data,
+        metadata : dict
+            Provenance information
+        supporting_arrays : dict
+            Supporting arrays, forwarded unchanged to GraphForecaster (the trainer passes this kwarg)
+
+        """
+        super().__init__(
+            config=config,
+            graph_data=graph_data,
+            statistics=statistics,
+            data_indices=data_indices,
+            metadata=metadata,
+            supporting_arrays=supporting_arrays,
+        )
+        name_to_index = data_indices.data.input.name_to_index
+        # A list keeps this a sequence even when a single forcing variable is configured
+        self.target_forcing_indices = [name_to_index[name] for name in config.training.target_forcing.data]
+        self.boundary_times = list(config.training.explicit_times.input)
+        self.interp_times = list(config.training.explicit_times.target)
+        # Position of every relative date index along the time axis of a batch
+        sorted_indices = sorted(set(self.boundary_times + self.interp_times))
+        self.imap = {data_index: batch_index for batch_index, data_index in enumerate(sorted_indices)}
+
+    def _step(
+        self,
+        batch: torch.Tensor,
+        batch_idx: int,
+        validation_mode: bool = False,
+    ) -> tuple[torch.Tensor, Mapping[str, torch.Tensor], list[torch.Tensor]]:
+        """Predict every target time from the boundary states and return the mean loss, metrics and predictions."""
+        del batch_idx
+        loss = torch.zeros(1, dtype=batch.dtype, device=self.device, requires_grad=False)
+        metrics = {}
+        y_preds = []
+
+        batch = self.model.pre_processors(batch)
+        boundary_positions = [self.imap[t] for t in self.boundary_times]
+        # (bs, time, ens, latlon, nvar)
+        x_bound = batch[:, boundary_positions][..., self.data_indices.data.input.full]
+
+        tfi = self.target_forcing_indices
+        t_first, t_last = self.boundary_times[0], self.boundary_times[1]
+        for interp_step in self.interp_times:
+            target = batch[:, self.imap[interp_step]]  # (bs, ens, latlon, nvar)
+
+            # Forcing at the target time plus its fractional position between the boundaries
+            # (0 at the first boundary, 1 at the second). A fresh tensor is built per step so
+            # nothing that an earlier forward pass may still reference is modified in place.
+            # TODO: make fraction time one of a config given set of arbitrary custom forcing functions.
+            forcing = target[..., tfi]
+            fraction = (interp_step - t_first) / (t_last - t_first)
+            time_fraction = forcing.new_full((*forcing.shape[:-1], 1), fraction)
+            target_forcing = torch.cat([forcing, time_fraction], dim=-1)
+
+            y_pred = self(x_bound, target_forcing)
+            y = target[..., self.data_indices.data.output.full]
+
+            loss += checkpoint(self.loss, y_pred, y, use_reentrant=False)
+
+            if validation_mode:
+                # calculate_val_metrics expects a rollout step; interp_step - 1 is passed in its place
+                metrics.update(self.calculate_val_metrics(y_pred, y, interp_step - 1))
+            y_preds.extend(y_pred)
+
+        loss *= 1.0 / len(self.interp_times)
+        return loss, metrics, y_preds
+
+    def forward(self, x: torch.Tensor, target_forcing: torch.Tensor) -> torch.Tensor:
+        """Run the wrapped model on the boundary states and the target forcing."""
+        return self.model(x, target_forcing, self.model_comm_group)
diff --git a/src/anemoi/training/train/train.py b/src/anemoi/training/train/train.py
index 694fb2da..8f6f9d46 100644
--- a/src/anemoi/training/train/train.py
+++ b/src/anemoi/training/train/train.py
@@ -15,6 +15,7 @@
 from functools import cached_property
 from pathlib import Path
 from typing import TYPE_CHECKING
+import importlib
 
 import hydra
 import numpy as np
@@ -142,7 +143,7 @@ def graph_data(self) -> HeteroData:
         )
 
     @cached_property
-    def model(self) -> GraphForecaster:
+    def model(self) -> pl.LightningModule:
         """Provide the model instance."""
         kwargs = {
             "config": self.config,
@@ -153,7 +154,10 @@ def model(self) -> GraphForecaster:
             "supporting_arrays": self.supporting_arrays,
         }
 
-        model = GraphForecaster(**kwargs)
+        # NOTE: hydra's instantiate would be preferable, but "config" clashes with its own first argument.
+        train_module = importlib.import_module(getattr(self.config.training, "train_module", "anemoi.training.train.forecaster"))
+        model_cls = getattr(train_module, getattr(self.config.training, "train_function", "GraphForecaster"))
+        model = model_cls(**kwargs)  # still needed by transfer_learning_loading below
 
         if self.load_weights_only:
             # Sanify the checkpoint for transfer learning
@@ -163,7 +167,7 @@ def model(self) -> GraphForecaster:
 
             LOGGER.info("Restoring only model weights from %s", self.last_checkpoint)
 
-            return GraphForecaster.load_from_checkpoint(self.last_checkpoint, **kwargs, strict=False)
+            return model_cls.load_from_checkpoint(self.last_checkpoint, **kwargs, strict=False)
 
         LOGGER.info("Model initialised from scratch.")
         return model
diff --git a/src/anemoi/training/utils/usable_indices.py b/src/anemoi/training/utils/usable_indices.py
index 0d97f25f..bef38090 100644
--- a/src/anemoi/training/utils/usable_indices.py
+++ b/src/anemoi/training/utils/usable_indices.py
@@ -14,13 +14,12 @@
 
 
 def get_usable_indices(
     missing_indices: set[int] | None,
     series_length: int,
-    rollout: int,
-    multistep: int,
-    timeincrement: int = 1,
+    relative_indices: np.ndarray,
+    model_run_ids: np.ndarray | None = None,
 ) -> np.ndarray:
-    """Get the usable indices of a series whit missing indices.
+    """Get the usable indices of a series with missing indices.
 
     Parameters
     ----------
@@ -28,32 +27,37 @@ def get_usable_indices(
         Dataset to be used.
     series_length : int
         Length of the series.
-    rollout : int
-        Number of steps to roll out.
-    multistep : int
-        Number of previous indices to include as predictors.
-    timeincrement : int
-        Time increment, by default 1.
+    relative_indices : np.ndarray
+        Relative indices requested at each index i (the indices i + relative_indices are loaded).
+    model_run_ids : np.ndarray, optional
+        Integer array of length ``series_length`` holding the model run id of each index.
+        None when training on analysis, by default None.
 
     Returns
     -------
     usable_indices : np.array
         Array of usable indices.
     """
-    prev_invalid_dates = (multistep - 1) * timeincrement
-    next_invalid_dates = rollout * timeincrement
-
-    usable_indices = np.arange(series_length)  # set of all indices
-
-    if missing_indices is None:
-        missing_indices = set()
-
-    missing_indices |= {-1, series_length}  # to filter initial and final indices
-
-    # Missing indices
-    for i in missing_indices:
-        usable_indices = usable_indices[
-            (usable_indices < i - next_invalid_dates) + (usable_indices > i + prev_invalid_dates)
-        ]
+    relative_indices = np.asarray(relative_indices, dtype=np.int64)
+
+    # Only start indices whose whole window lies inside [0, series_length) are candidates
+    first = max(0, -int(relative_indices.min()))
+    last = series_length - max(0, int(relative_indices.max()))
+    usable_indices = np.arange(first, last)
+
+    # Absolute indices requested for each candidate, shape (n_relative, n_candidates)
+    requested = usable_indices[np.newaxis, :] + relative_indices[:, np.newaxis]
+
+    # Avoid crossing model runs by keeping only windows with a single model run id
+    if model_run_ids is not None:
+        run_ids = np.asarray(model_run_ids)[requested]
+        same_run = (run_ids == run_ids[0]).all(axis=0)
+        usable_indices, requested = usable_indices[same_run], requested[:, same_run]
+
+    # Drop windows that touch a missing index (one vectorised pass, the caller's set is not modified)
+    if missing_indices:
+        is_missing = np.zeros(series_length, dtype=bool)
+        is_missing[[i for i in missing_indices if 0 <= i < series_length]] = True
+        usable_indices = usable_indices[~is_missing[requested].any(axis=0)]
 
     return usable_indices
diff --git a/tests/utils/test_usable_indices.py b/tests/utils/test_usable_indices.py
index 0aff358a..5dcd671e 100644
--- a/tests/utils/test_usable_indices.py
+++ b/tests/utils/test_usable_indices.py
@@ -16,33 +16,52 @@
 def test_get_usable_indices() -> None:
     """Test get_usable_indices function."""
     # Test base case
-    valid_indices = get_usable_indices(missing_indices=None, series_length=10, rollout=1, multistep=1, timeincrement=1)
+    valid_indices = get_usable_indices(missing_indices=set(), series_length=10, relative_indices=np.array([0, 1]))
     expected_values = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8])
     assert np.allclose(valid_indices, expected_values)
 
-    # Test multiple steps inputs
-    valid_indices = get_usable_indices(missing_indices=None, series_length=10, rollout=1, multistep=2, timeincrement=1)
-    expected_values = np.array([1, 2, 3, 4, 5, 6, 7, 8])
-    assert np.allclose(valid_indices, expected_values)
-
-    # Test roll out
-    valid_indices = get_usable_indices(missing_indices=None, series_length=10, rollout=2, multistep=1, timeincrement=1)
+    # Test 3 indices, either from rollout = 1 and multistep = 2, or rollout = 2 and multistep = 1
+    valid_indices = get_usable_indices(missing_indices=set(), series_length=10, relative_indices=np.array([0, 1, 2]))
     expected_values = np.array([0, 1, 2, 3, 4, 5, 6, 7])
     assert np.allclose(valid_indices, expected_values)
 
-    # Test longer time increments
-    valid_indices = get_usable_indices(missing_indices=None, series_length=10, rollout=1, multistep=2, timeincrement=2)
-    expected_values = np.array([2, 3, 4, 5, 6, 7])
+    # With time increment
+    valid_indices = get_usable_indices(missing_indices=set(), series_length=10, relative_indices=np.array([0, 2, 4]))
+    expected_values = np.array([0, 1, 2, 3, 4, 5])
     assert np.allclose(valid_indices, expected_values)
 
-    # Test missing indices
-    missing_indices = {7, 5}
+    # Test missing indices with standard setup
+    missing_indices = {7, 5, 14}
     valid_indices = get_usable_indices(
         missing_indices=missing_indices,
-        series_length=10,
-        rollout=1,
-        multistep=2,
-        timeincrement=1,
+        series_length=20,
+        relative_indices=np.array([0, 1, 2]),
     )
-    expected_values = np.array([1, 2, 3])
+    expected_values = np.array([0, 1, 2, 8, 9, 10, 11, 15, 16, 17])
     assert np.allclose(valid_indices, expected_values)
+
+    # Now verify that with a non-standard setup, missing indices can be "jumped" over.
+    valid_indices = get_usable_indices(
+        missing_indices=missing_indices,
+        series_length=20,
+        relative_indices=np.array([0, 5, 6, 7]),  # e.g. making a 1 hour forecast based on -6, -1 and 0 hours.
+    )
+    expected_values = np.array([3, 4, 6, 10, 11, 12])
+    assert np.allclose(valid_indices, expected_values)
+
+    # Test functionality for avoiding using different model runs
+    series_length = 70
+    mr_start, mr_length = [4, 18]
+    model_run_ids = (np.arange(series_length, dtype=np.int64) - mr_start) // mr_length
+    valid_indices = get_usable_indices(
+        missing_indices=set(range(40, 58)) | {11},  # one model run of length 18 missing, and one sample of another run.
+        series_length=series_length,
+        relative_indices=np.array([0, 3, 6]),
+        model_run_ids=model_run_ids,
+    )
+    expected_values = np.array([4, 6, 7, 9, 10, 12, 13, 14, 15] + list(range(22, 40 - 6)) + [58, 59, 60, 61, 62, 63])
+    assert np.allclose(valid_indices, expected_values)
+
+    # missing_indices=None is still accepted and means "nothing missing"
+    valid_indices = get_usable_indices(missing_indices=None, series_length=10, relative_indices=np.array([0, 1]))
+    assert np.allclose(valid_indices, np.arange(9))