diff --git a/src/anemoi/training/data/dataset.py b/src/anemoi/training/data/dataset.py
index 1967fb82..fccde08d 100644
--- a/src/anemoi/training/data/dataset.py
+++ b/src/anemoi/training/data/dataset.py
@@ -13,7 +13,6 @@
 import os
 import random
 from functools import cached_property
-from typing import TYPE_CHECKING
 from typing import Callable

 import numpy as np
@@ -27,9 +26,6 @@
 LOGGER = logging.getLogger(__name__)

-if TYPE_CHECKING:
-    from anemoi.training.data.grid_indices import BaseGridIndices
-

 class NativeGridDataset(IterableDataset):
     """Iterable dataset for AnemoI data on the arbitrary grids."""

@@ -37,17 +33,12 @@ class NativeGridDataset(IterableDataset):
     def __init__(
         self,
         data_reader: Callable,
-        grid_indices: type[BaseGridIndices],
-        rollout: int = 1,
-        multistep: int = 1,
-        timeincrement: int = 1,
-        relative_date_indices: list = [0, 1, 2],
+        relative_date_indices: list = [0, 1, 2],
         model_comm_group_rank: int = 0,
         model_comm_group_id: int = 0,
         model_comm_num_groups: int = 1,
         shuffle: bool = True,
         label: str = "generic",
-        effective_bs: int = 1,
     ) -> None:
         """Initialize (part of) the dataset state.

@@ -55,14 +46,6 @@ def __init__(
         ----------
         data_reader : Callable
             user function that opens and returns the zarr array data
-        grid_indices : Type[BaseGridIndices]
-            indices of the grid to keep. Defaults to None, which keeps all spatial indices.
-        rollout : int, optional
-            length of rollout window, by default 12
-        timeincrement : int, optional
-            time increment between samples, by default 1
-        multistep : int, optional
-            collate (t-1, ... t - multistep) into the input state vector, by default 1
         relative_date_indices : list
             list of time indices to load from the data relative to the current sample i in __iter__
         model_comm_group_rank : int, optional
@@ -75,30 +58,21 @@
             Shuffle batches, by default True
         label : str, optional
             label for the dataset, by default "generic"
-        effective_bs : int, default 1
-            effective batch size useful to compute the lenght of the dataset
+
         """
         self.label = label
-        self.effective_bs = effective_bs

         self.data = data_reader

-        self.rollout = rollout
-        self.timeincrement = timeincrement
-        self.grid_indices = grid_indices
-
-        # lazy init
         self.n_samples_per_epoch_total: int = 0
         self.n_samples_per_epoch_per_worker: int = 0

-        # lazy init model and reader group info, will be set by the DDPGroupStrategy:
-        self.model_comm_group_rank = 0
-        self.model_comm_num_groups = 1
-        self.model_comm_group_id = 0
-        self.global_rank = 0
-
-        self.reader_group_rank = 0
-        self.reader_group_size = 1
+        # DDP-relevant info
+        self.model_comm_group_rank = model_comm_group_rank
+        self.model_comm_num_groups = model_comm_num_groups
+        self.model_comm_group_id = model_comm_group_id
+        self.global_rank = int(os.environ.get("SLURM_PROCID", "0"))

         # additional state vars (lazy init)
         self.n_samples_per_worker = 0
@@ -122,11 +96,6 @@ def metadata(self) -> dict:
         """Return dataset metadata."""
         return self.data.metadata()

-    @cached_property
-    def supporting_arrays(self) -> dict:
-        """Return dataset supporting_arrays."""
-        return self.data.supporting_arrays()
-
     @cached_property
     def name_to_index(self) -> dict:
         """Return dataset statistics."""
@@ -149,57 +118,7 @@ def valid_date_indices(self) -> np.ndarray:
             dataset length minus rollout minus additional multistep inputs (if time_increment is 1).
         """
-        return get_usable_indices(
-            self.data.missing,
-            len(self.data),
-            np.array(self.relative_date_indices, dtype=np.int64),
-            self.data.model_run_ids,
-        )
-
-    def set_comm_group_info(
-        self,
-        global_rank: int,
-        model_comm_group_id: int,
-        model_comm_group_rank: int,
-        model_comm_num_groups: int,
-        reader_group_rank: int,
-        reader_group_size: int,
-    ) -> None:
-        """Set model and reader communication group information (called by DDPGroupStrategy).
-
-        Parameters
-        ----------
-        global_rank : int
-            Global rank
-        model_comm_group_id : int
-            Model communication group ID
-        model_comm_group_rank : int
-            Model communication group rank
-        model_comm_num_groups : int
-            Number of model communication groups
-        reader_group_rank : int
-            Reader group rank
-        reader_group_size : int
-            Reader group size
-        """
-        self.global_rank = global_rank
-        self.model_comm_group_id = model_comm_group_id
-        self.model_comm_group_rank = model_comm_group_rank
-        self.model_comm_num_groups = model_comm_num_groups
-        self.reader_group_rank = reader_group_rank
-        self.reader_group_size = reader_group_size
-
-        assert self.reader_group_size >= 1, "reader_group_size must be positive"
-
-        LOGGER.debug(
-            "NativeGridDataset.set_group_info(): global_rank %d, model_comm_group_id %d, "
-            "model_comm_group_rank %d, model_comm_num_groups %d, reader_group_rank %d",
-            global_rank,
-            model_comm_group_id,
-            model_comm_group_rank,
-            model_comm_num_groups,
-            reader_group_rank,
-        )
+        return get_usable_indices(self.data.missing, len(self.data), np.array(self.relative_date_indices, dtype=np.int64), self.data.model_run_ids)

     def per_worker_init(self, n_workers: int, worker_id: int) -> None:
         """Called by worker_init_func on each copy of dataset.

@@ -226,7 +145,6 @@ def per_worker_init(self, n_workers: int, worker_id: int) -> None:
         low = shard_start + worker_id * self.n_samples_per_worker
         high = min(shard_start + (worker_id + 1) * self.n_samples_per_worker, shard_end)

-        self.chunk_index_range = np.arange(low, high, dtype=np.uint32)

         LOGGER.debug(
             "Worker %d (pid %d, global_rank %d, model comm group %d) has low/high range %d / %d",
@@ -238,17 +156,27 @@ def per_worker_init(self, n_workers: int, worker_id: int) -> None:
             high,
         )

+        self.chunk_index_range = self.valid_date_indices[np.arange(low, high, dtype=np.uint32)]
+
+        # each worker must have a different seed for its random number generator,
+        # otherwise all the workers will output exactly the same data
+        # should we check the lightning env variable "PL_SEED_WORKERS" here?
+        # but we always want to seed these anyway ...
+
         base_seed = get_base_seed()
-        torch.manual_seed(base_seed)
-        random.seed(base_seed)
-        self.rng = np.random.default_rng(seed=base_seed)
+        seed = (
+            base_seed * (self.model_comm_group_id + 1) - worker_id
+        )  # note that test, validation etc. datasets get the same seed
+        torch.manual_seed(seed)
+        random.seed(seed)
+        self.rng = np.random.default_rng(seed=seed)

         sanity_rnd = self.rng.random(1)

         LOGGER.debug(
             (
                 "Worker %d (%s, pid %d, glob. rank %d, model comm group %d, "
-                "group_rank %d, base_seed %d), sanity rnd %f"
+                "group_rank %d, base_seed %d) using seed %d, sanity rnd %f"
             ),
             worker_id,
             self.label,
@@ -257,6 +185,7 @@ def per_worker_init(self, n_workers: int, worker_id: int) -> None:
             self.model_comm_group_id,
             self.model_comm_group_rank,
             base_seed,
+            seed,
             sanity_rnd,
         )

@@ -271,12 +200,12 @@ def __iter__(self) -> torch.Tensor:
         """
         if self.shuffle:
             shuffled_chunk_indices = self.rng.choice(
-                self.valid_date_indices,
-                size=len(self.valid_date_indices),
+                self.chunk_index_range,
+                size=self.n_samples_per_worker,
                 replace=False,
-            )[self.chunk_index_range]
+            )
         else:
-            shuffled_chunk_indices = self.valid_date_indices[self.chunk_index_range]
+            shuffled_chunk_indices = self.chunk_index_range

         LOGGER.debug(
             (
@@ -293,9 +222,7 @@ def __iter__(self) -> torch.Tensor:
         )

         for i in shuffled_chunk_indices:
-            grid_shard_indices = self.grid_indices.get_shard_indices(self.reader_group_rank)
-            x = x[..., grid_shard_indices]  # select the grid shard
-            x = self.data[self.relative_date_indices + i]  # NOTE: this requires an update to anemoi datasets
+            x = self.data[self.relative_date_indices + i]  # NOTE: this requires an update to anemoi datasets
             x = rearrange(x, "dates variables ensemble gridpoints -> dates ensemble gridpoints variables")

             self.ensemble_dim = 1
@@ -333,4 +260,4 @@ def worker_init_func(worker_id: int) -> None:
     dataset_obj.per_worker_init(
         n_workers=worker_info.num_workers,
         worker_id=worker_id,
-    )
+    )
\ No newline at end of file
diff --git a/src/anemoi/training/train/interpolator.py b/src/anemoi/training/train/interpolator.py
index 4c330d64..66ecf0ca 100644
--- a/src/anemoi/training/train/interpolator.py
+++ b/src/anemoi/training/train/interpolator.py
@@ -9,35 +9,15 @@

 import logging
-import math
-import os
-from collections import defaultdict
-from collections.abc import Generator
 from collections.abc import Mapping
-from typing import Optional
-from typing import Union
 from operator import itemgetter

-import numpy as np
-import pytorch_lightning as pl
 import torch
 from anemoi.models.data_indices.collection import IndexCollection
-from anemoi.models.interface import AnemoiModelInterface
-from anemoi.utils.config import DotDict
-from hydra.utils import instantiate
 from omegaconf import DictConfig
-from omegaconf import OmegaConf
-from timm.scheduler import CosineLRScheduler
-from torch.distributed.distributed_c10d import ProcessGroup
-from torch.distributed.optim import ZeroRedundancyOptimizer
 from torch.utils.checkpoint import checkpoint
 from torch_geometric.data import HeteroData

-from anemoi.training.losses.utils import grad_scaler
-from anemoi.training.losses.weightedloss import BaseWeightedLoss
-from anemoi.training.utils.jsonify import map_config_to_primitives
-from anemoi.training.utils.masks import Boolean1DMask
-from anemoi.training.utils.masks import NoOutputMask
 from anemoi.training.train.forecaster import GraphForecaster

 LOGGER = logging.getLogger(__name__)

@@ -54,6 +34,7 @@ def __init__(
         statistics: dict,
         data_indices: IndexCollection,
         metadata: dict,
+        supporting_arrays: dict,
     ) -> None:
         """Initialize graph neural network interpolator.

@@ -71,7 +52,9 @@ def __init__(
             Provenance information
+        supporting_arrays : dict
+            Supporting arrays forwarded to GraphForecaster

         """
-        super().__init__(config = config, graph_data = graph_data, statistics = statistics, data_indices = data_indices, metadata = metadata)
+        super().__init__(config=config, graph_data=graph_data, statistics=statistics, data_indices=data_indices, metadata=metadata, supporting_arrays=supporting_arrays)
         self.target_forcing_indices = itemgetter(*config.training.target_forcing.data)(data_indices.data.input.name_to_index)
         if type(self.target_forcing_indices) == int:
             self.target_forcing_indices = [self.target_forcing_indices]
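The type check after the super().__init__ call exists because itemgetter returns a bare value for a single key but a tuple for several, so a single target-forcing variable must be normalised to a list. A small self-contained illustration (the variable names in the mapping are made up; the real keys come from data_indices.data.input.name_to_index):

    from operator import itemgetter

    name_to_index = {"2t": 0, "10u": 1, "insolation": 2}  # illustrative mapping

    several = itemgetter(*["10u", "insolation"])(name_to_index)
    single = itemgetter(*["insolation"])(name_to_index)
    assert several == (1, 2)  # tuple for two or more keys
    assert single == 2        # bare int for exactly one key

    target_forcing_indices = single
    if isinstance(target_forcing_indices, int):  # mirrors the guard in __init__
        target_forcing_indices = [target_forcing_indices]
    assert target_forcing_indices == [2]

(isinstance is used in the sketch; the patch's `type(...) == int` behaves the same for plain ints.)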
diff --git a/src/anemoi/training/utils/checkpoint.py b/src/anemoi/training/utils/checkpoint.py
index a78ef524..3977567b 100644
--- a/src/anemoi/training/utils/checkpoint.py
+++ b/src/anemoi/training/utils/checkpoint.py
@@ -16,8 +16,6 @@
 import torch.nn as nn
 from anemoi.utils.checkpoints import save_metadata

-from anemoi.training.train.forecaster import GraphForecaster
-
 LOGGER = logging.getLogger(__name__)


@@ -35,6 +33,8 @@ def load_and_prepare_model(lightning_checkpoint_path: str) -> tuple[torch.nn.Mod
         pytorch model, metadata

     """
+    from anemoi.training.train.forecaster import GraphForecaster
+
     module = GraphForecaster.load_from_checkpoint(lightning_checkpoint_path)
     model = module.model
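The checkpoint.py change moves the GraphForecaster import from module level into load_and_prepare_model, presumably to break an import cycle between the forecaster and the checkpoint utilities. A minimal sketch of the deferred-import pattern with two illustrative stand-in modules (not the real anemoi module layout):

    # forecaster_mod.py
    import checkpoint_util  # imports the utility module at import time

    class Forecaster:
        @classmethod
        def load_from_checkpoint(cls, path: str) -> "Forecaster":
            return cls()

    # checkpoint_util.py
    def load_and_prepare_model(path: str):
        # Deferred import: by the time this runs, forecaster_mod is fully
        # initialised, so forecaster_mod -> checkpoint_util -> forecaster_mod
        # no longer fails during module initialisation.
        from forecaster_mod import Forecaster

        return Forecaster.load_from_checkpoint(path)

After the first call the deferred import costs only a sys.modules lookup, which is negligible on a checkpoint-loading path.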