diff --git a/bris/__main__.py b/bris/__main__.py index a0243ac..c30328e 100644 --- a/bris/__main__.py +++ b/bris/__main__.py @@ -2,6 +2,7 @@ from argparse import ArgumentParser import numpy as np import os +from datetime import datetime, timedelta from hydra.utils import instantiate @@ -12,7 +13,6 @@ from .checkpoint import Checkpoint from .inference import Inference -from .predict_metadata import PredictMetadata from .utils import create_config from .writer import CustomWriter @@ -36,24 +36,57 @@ def main(): LOGGER.info("Update graph is enabled. Proceeding to change internal graph") checkpoint.update_graph(config.model.graph) # Pass in a new graph if needed + # Get timestep from checkpoint. Also store a version in seconds for local use. + config.timestep = None + try: + config.timestep = checkpoint.config.data.timestep + except KeyError: + raise RuntimeError("Error getting timestep from checkpoint (checkpoint.config.data.timestep)") + timestep_seconds = frequency_to_seconds(config.timestep) + + num_members = 1 + + # Get multistep. A default of 2 to ignore multistep in start_date calculation if not set. + multistep = 2 + try: + multistep = checkpoint.config.training.multistep_input + except KeyError: + LOGGER.debug("Multistep not found in checkpoint") + + # If no start_date given, calculate as end_date-((multistep-1)*timestep) + if "start_date" not in config: + config.start_date = datetime.strftime( + datetime.strptime(config.end_date, "%Y-%m-%dT%H:%M:%S") - timedelta(seconds=(multistep - 1) * timestep_seconds), + "%Y-%m-%dT%H:%M:%S" + ) + LOGGER.info("No start_date given, setting %s based on start_date and timestep.", config.start_date) + else: + config.start_date = datetime.strftime( + datetime.strptime(config.start_date, "%Y-%m-%dT%H:%M:%S") - timedelta(seconds=(multistep - 1) * timestep_seconds), + "%Y-%m-%dT%H:%M:%S" + ) + + config.dataset = {"dataset": config.dataset, "start": config.start_date, "end": config.end_date, "frequency": config.frequency} + datamodule = DataModule( config=config, checkpoint_object=checkpoint, ) - - # Assemble outputs - workdir = config.hardware.paths.workdir - num_members = 1 # Get outputs and required_variables of each decoder - timestep = frequency_to_seconds(config.timestep) - leadtimes = np.arange(config.leadtimes) * timestep + leadtimes = np.arange(config.leadtimes) * timestep_seconds decoder_outputs = bris.routes.get( - config["routing"], leadtimes, num_members, datamodule, workdir + config["routing"], leadtimes, num_members, datamodule, config.workdir ) required_variables = bris.routes.get_required_variables(config["routing"], datamodule) writer = CustomWriter(decoder_outputs, write_interval="batch") + # Set hydra defaults + config.defaults = [ + {'override hydra/job_logging': 'none'}, # disable config parsing logs + {'override hydra/hydra_logging': 'none'}, # disable config parsing logs + '_self_'] + # Forecaster must know about what leadtimes to output model = instantiate( config.model, @@ -63,7 +96,6 @@ def main(): forecast_length=config.leadtimes, required_variables=required_variables, release_cache=config.release_cache, - ) callbacks = list() diff --git a/bris/data/datamodule.py b/bris/data/datamodule.py index 3500af6..cf20257 100644 --- a/bris/data/datamodule.py +++ b/bris/data/datamodule.py @@ -1,7 +1,6 @@ import logging -import os from functools import cached_property -from typing import Any, Optional +from typing import Any, Optional, TYPE_CHECKING import numpy as np import pytorch_lightning as pl @@ -9,15 +8,18 @@ from anemoi.utils.config 
import DotDict from anemoi.utils.dates import frequency_to_seconds from hydra.utils import instantiate + from omegaconf import DictConfig, OmegaConf, errors -from torch.utils.data import DataLoader, get_worker_info -from torch_geometric.data import HeteroData +from torch.utils.data import DataLoader, IterableDataset, get_worker_info import anemoi.datasets.data.subset import anemoi.datasets.data.select from bris.checkpoint import Checkpoint -from bris.data.dataset import Dataset -from bris.utils import check_anemoi_training, recursive_list_to_tuple +from bris.data.dataset import worker_init_func +from bris.data.grid_indices import FullGrid +from bris.utils import recursive_list_to_tuple + +from bris.data.grid_indices import BaseGridIndices LOGGER = logging.getLogger(__name__) @@ -47,8 +49,7 @@ def __init__( self.ckptObj = checkpoint_object self.timestep = config.timestep self.frequency = config.frequency - self.legacy = not check_anemoi_training(metadata=self.ckptObj._metadata) - + def predict_dataloader(self) -> DataLoader: """ Creates a dataloader for prediction @@ -57,34 +58,19 @@ def predict_dataloader(self) -> DataLoader: None return: - """ - return self._get_dataloader(self.ds_predict) - - def _get_dataloader(self, ds): - """ - Creates torch dataloader object for - ds. Batch_size, num_workers, prefetch_factor - and pin_memory can be adjusted in the config - under dataloader. - - args: - ds: anemoi.datasets.data.open_dataset object - - return: - torch dataloader initialized on anemoi dataset object """ return DataLoader( - ds, - batch_size=self.config.dataloader.batch_size, + self.ds_predict, + batch_size=1, # number of worker processes - num_workers=self.config.dataloader.num_workers, + num_workers=self.config.dataloader.get("num_workers", 1), # use of pinned memory can speed up CPU-to-GPU data transfers # see https://pytorch.org/docs/stable/notes/cuda.html#cuda-memory-pinning pin_memory=self.config.dataloader.get("pin_memory", True), # worker initializer worker_init_fn=worker_init_func, # prefetch batches - prefetch_factor=self.config.dataloader.prefetch_factor, + prefetch_factor=self.config.dataloader.get("prefetch_factor", 2), persistent_workers=True, ) @@ -99,100 +85,22 @@ def ds_predict(self) -> Any: Anemoi dataset open_dataset object """ return self._get_dataset(self.data_reader) - + def _get_dataset( self, data_reader, - ): - """ - Instantiates a given dataset class - from anemoi.training.data.dataset. - This assumes that the python path for - the class is defined, and anemoi-training - for a given branch is installed with pip - in order to access the class. This - method returns an instantiated instance of - a given data class. This supports - data distributed parallel (DDP) and model - sharding. - - args: - data_reader: anemoi open_dataset object - - return: - an dataset class object - """ - if self.legacy: - # TODO: fix imports and pip packages for legacy version - LOGGER.info( - """Did not find anemoi.training version in checkpoint metadata, assuming - the model was trained with aifs-mono and using legacy functionality""" - ) - LOGGER.warning("WARNING! 
Ensemble legacy mode has yet to be implemented!") - from .legacy.dataset import EnsNativeGridDataset, NativeGridDataset - from .legacy.utils import _legacy_slurm_proc_id - - model_comm_group_rank, model_comm_group_id, model_comm_num_groups = ( - _legacy_slurm_proc_id(self.config) - ) - - spatial_mask = {} - for mesh_name, mesh in self.graph.items(): - if ( - isinstance(mesh_name, str) - and mesh_name - != self.ckptObj._metadata.config.graphs.hidden_mesh.name - ): - spatial_mask[mesh_name] = mesh.get("dataset_idx", None) - spatial_index = spatial_mask[ - self.ckptObj._metadata.config.graphs.encoders[0]["src_mesh"] - ] - - dataCls = NativeGridDataset( - data_reader=data_reader, - rollout=0, # we dont perform rollout during inference - multistep=self.ckptObj.multistep, - timeincrement=self.timeincrement, - model_comm_group_rank=model_comm_group_rank, - model_comm_group_id=model_comm_group_id, - model_comm_num_groups=model_comm_num_groups, - spatial_index=spatial_index, - shuffle=False, - label="predict", - ) - return dataCls - else: - try: - dataCls = instantiate( - config=self.config.dataloader.datamodule, - data_reader=data_reader, - rollout=0, # we dont perform rollout during inference - multistep=self.ckptObj.multistep, - timeincrement=self.timeincrement, - grid_indices=self.grid_indices, - shuffle=False, - label="predict", - ) - except: - dataCls = instantiate( - config=self.config.dataloader.datamodule, - data_reader=data_reader, - rollout=0, # we dont perform rollout during inference - multistep=self.ckptObj.multistep, - timeincrement=self.timeincrement, - shuffle=False, - label="predict", - ) + ) -> IterableDataset: + ds = instantiate( + config=self.config.dataloader.datamodule, + data_reader=data_reader, + rollout=0, + multistep=self.ckptObj.multistep, + timeincrement=self.timeincrement, + grid_indices=self.grid_indices, + label="predict", + ) - return Dataset(dataCls) - - @property - def name_to_index(self): - """ - Returns a tuple of dictionaries, where each dict is: - variable_name -> index - """ - return self.ckptObj.name_to_index + return ds @cached_property def data_reader(self): @@ -211,7 +119,7 @@ def data_reader(self): An anemoi open_dataset object """ base_loader = OmegaConf.to_container( - self.config.dataloader.predict, + self.config.dataset, resolve=True ) return open_dataset(base_loader) @@ -243,23 +151,30 @@ def timeincrement(self) -> int: timestep, ) return timestep // frequency - - @cached_property - def grid_indices(self): - if check_anemoi_training(self.ckptObj.metadata): - try: - from anemoi.training.data.grid_indices import BaseGridIndices, FullGrid - except ImportError as e: - print("Warning! Could not import BaseGridIndices and FullGrid. 
Continuing without this module") - reader_group_size = self.config.dataloader.read_group_size - grid_indices = FullGrid( - nodes_name="data", - reader_group_size=reader_group_size + @property + def name_to_index(self): + """ + Returns a tuple of dictionaries, where each dict is: + variable_name -> index + """ + return self.ckptObj.name_to_index + + @cached_property + def grid_indices(self) -> type[BaseGridIndices]: + reader_group_size = 1 + if hasattr(self.config.dataloader, "grid_indices"): + grid_indices = instantiate(self.config.dataloder.grid_indices, reader_group_size=reader_group_size) + LOGGER.info("Using grid indices from dataloader config") + else: + grid_indices = FullGrid( + nodes_name="data", + reader_group_size=reader_group_size ) + LOGGER.info("grid_indices not found in dataloader config, defaulting to FullGrid") grid_indices.setup(self.graph) return grid_indices - + @cached_property def grids(self) -> tuple: """ @@ -311,10 +226,11 @@ def altitudes(self) -> tuple: return altitudes - @cached_property def field_shape(self) -> tuple: - + """ + Retrieve field_shape of the datasets + """ field_shape = [None]*len(self.grids) for decoder_index, grids in enumerate(self.grids): field_shape[decoder_index] = [None]*len(grids) @@ -344,30 +260,3 @@ def _get_field_shape(self, decoder_index, dataset_index): assert (decoder_index == 0 and dataset_index == 0) return data_reader.field_shape -def worker_init_func(worker_id: int) -> None: - """Configures each dataset worker process. - - Calls WeatherBenchDataset.per_worker_init() on each dataset object. - - Parameters - ---------- - worker_id : int - Worker ID - - Raises - ------ - RuntimeError - If worker_info is None - - """ - worker_info = get_worker_info() # information specific to each worker process - if worker_info is None: - # LOGGER.error("worker_info is None! Set num_workers > 0 in your dataloader!") - raise RuntimeError - dataset_obj = ( - worker_info.dataset - ) # the copy of the dataset held by this worker process. - dataset_obj.per_worker_init( - n_workers=worker_info.num_workers, - worker_id=worker_id, - ) diff --git a/bris/data/dataset.py b/bris/data/dataset.py index 6bb49b8..48651ef 100644 --- a/bris/data/dataset.py +++ b/bris/data/dataset.py @@ -1,50 +1,105 @@ from typing import Any import torch +import numpy as np +import random +import logging + from numpy import datetime64 -from torch.utils.data import IterableDataset from einops import rearrange +from functools import cached_property +from typing import Callable, TYPE_CHECKING + +from torch.utils.data import IterableDataset +from torch.utils.data import get_worker_info +from bris.utils import get_usable_indices, get_base_seed + + +from bris.data.grid_indices import BaseGridIndices + + +LOGGER = logging.getLogger(__name__) +class NativeGridDataset(IterableDataset): + """Iterable dataset for AnemoI data on the arbitrary grids.""" -class Dataset(IterableDataset): def __init__( self, - dataCls: Any, - ): - """ - Wrapper for a given anemoi.training.data.dataset class - to include timestamp in the iterator. + data_reader: Callable, + grid_indices: type[BaseGridIndices], + rollout: int = 1, + multistep: int = 1, + timeincrement: int = 1, + label: str = "generic", + effective_bs: int = 1, + ) -> None: + """Initialize (part of) the dataset state. + + Parameters + ---------- + data_reader : Callable + user function that opens and returns the zarr array data + grid_indices : Type[BaseGridIndices] + indices of the grid to keep. Defaults to None, which keeps all spatial indices. 
+ rollout : int, optional + length of rollout window, by default 12 + timeincrement : int, optional + time increment between samples, by default 1 + multistep : int, optional + collate (t-1, ... t - multistep) into the input state vector, by default 1 + shuffle : bool, optional + Shuffle batches, by default True + label : str, optional + label for the dataset, by default "generic" + effective_bs : int, default 1 + effective batch size useful to compute the lenght of the dataset """ + self.label = label + self.effective_bs = effective_bs - super().__init__() - if hasattr(dataCls, "data"): - self.data = dataCls.data - else: - raise RuntimeError("dataCls does not have attribute data") - self.dataCls = dataCls + self.data = data_reader - def per_worker_init(self, n_workers, worker_id): - """ - Delegate per_worker_init to the underlying dataset. - Called by worker_init_func on each copy of dataset. + self.rollout = rollout + self.timeincrement = timeincrement + self.grid_indices = grid_indices - This initialises after the worker process has been spawned. + # lazy init + self.n_samples_per_epoch_total: int = 0 + self.n_samples_per_epoch_per_worker: int = 0 - Parameters - ---------- - n_workers : int - Number of workers - worker_id : int - Worker ID + # lazy init model and reader group info, will be set by the DDPGroupStrategy: + self.model_comm_group_rank = 0 + self.model_comm_num_groups = 1 + self.model_comm_group_id = 0 + self.global_rank = 0 + + self.reader_group_rank = 0 + self.reader_group_size = 1 + + # additional state vars (lazy init) + self.n_samples_per_worker = 0 + self.chunk_index_range: np.ndarray | None = None + + # Data dimensions + self.multi_step = multistep + assert self.multi_step > 0, "Multistep value must be greater than zero." + self.ensemble_dim: int = 2 + self.ensemble_size = self.data.shape[self.ensemble_dim] + + @cached_property + def valid_date_indices(self) -> np.ndarray: + """Return valid date indices. + A date t is valid if we can sample the sequence + (t - multistep + 1, ..., t + rollout) + without missing data (if time_increment is 1). + + If there are no missing dates, total number of valid ICs is + dataset length minus rollout minus additional multistep inputs + (if time_increment is 1). """ - if hasattr(self.dataCls, "per_worker_init"): - self.dataCls.per_worker_init(n_workers=n_workers, worker_id=worker_id) - else: - raise RuntimeError( - "Warning: Underlying dataset does not implement 'per_worker_init'." - ) - + return get_usable_indices(self.data.missing, len(self.data), self.rollout, self.multi_step, self.timeincrement) + def set_comm_group_info( self, global_rank: int, @@ -54,24 +109,122 @@ def set_comm_group_info( reader_group_rank: int, reader_group_size: int, ) -> None: - if hasattr(self.dataCls, "set_comm_group_info"): - self.dataCls.set_comm_group_info( - global_rank=global_rank, - model_comm_group_id=model_comm_group_id, - model_comm_group_rank=model_comm_group_rank, - model_comm_num_groups=model_comm_num_groups, - reader_group_rank=reader_group_rank, - reader_group_size=reader_group_size, - ) - else: - raise RuntimeError( - "Warning: Underlying dataset does not implement 'set_comm_group_info'." - ) - - - def __iter__( - self, - ) -> tuple[torch.Tensor, datetime64] | tuple[tuple[torch.Tensor], datetime64]: + """Set model and reader communication group information (called by DDPGroupStrategy). 
+ + Parameters + ---------- + global_rank : int + Global rank + model_comm_group_id : int + Model communication group ID + model_comm_group_rank : int + Model communication group rank + model_comm_num_groups : int + Number of model communication groups + reader_group_rank : int + Reader group rank + reader_group_size : int + Reader group size + """ + self.global_rank = global_rank + self.model_comm_group_id = model_comm_group_id + self.model_comm_group_rank = model_comm_group_rank + self.model_comm_num_groups = model_comm_num_groups + self.reader_group_rank = reader_group_rank + self.reader_group_size = reader_group_size + + assert self.reader_group_size >= 1, "reader_group_size must be positive" + + LOGGER.debug( + "NativeGridDataset.set_group_info(): global_rank %d, model_comm_group_id %d, " + "model_comm_group_rank %d, model_comm_num_groups %d, reader_group_rank %d", + global_rank, + model_comm_group_id, + model_comm_group_rank, + model_comm_num_groups, + reader_group_rank, + ) + + def per_worker_init(self, n_workers: int, worker_id: int) -> None: + """Called by worker_init_func on each copy of dataset. + + This initialises after the worker process has been spawned. + + Parameters + ---------- + n_workers : int + Number of workers + worker_id : int + Worker ID + + """ + self.worker_id = worker_id + + # Divide this equally across shards (one shard per group!) + shard_size = len(self.valid_date_indices) // self.model_comm_num_groups + shard_start = self.model_comm_group_id * shard_size + shard_end = (self.model_comm_group_id + 1) * shard_size + + shard_len = shard_end - shard_start + self.n_samples_per_worker = shard_len // n_workers + + low = shard_start + worker_id * self.n_samples_per_worker + high = min(shard_start + (worker_id + 1) * self.n_samples_per_worker, shard_end) + self.chunk_index_range = np.arange(low, high, dtype=np.uint32) + + base_seed = get_base_seed() + + torch.manual_seed(base_seed) + random.seed(base_seed) + self.rng = np.random.default_rng(seed=base_seed) + + def __iter__(self) -> torch.Tensor: + """Return an iterator over the dataset. + + The datasets are retrieved by Anemoi Datasets from zarr files. This iterator yields + chunked batches for DDP and sharded training. + + Currently it receives data with an ensemble dimension, which is discarded for + now. (Until the code is "ensemble native".) + """ + + shuffled_chunk_indices = self.valid_date_indices[self.chunk_index_range] + + for i in shuffled_chunk_indices: + start = i - (self.multi_step - 1) * self.timeincrement + end = i + (self.rollout + 1) * self.timeincrement + + grid_shard_indices = self.grid_indices.get_shard_indices(self.reader_group_rank) + x = self.data[start : end : self.timeincrement, :, :, :] + x = x[..., grid_shard_indices] # select the grid shard + x = rearrange(x, "dates variables ensemble gridpoints -> dates ensemble gridpoints variables") + self.ensemble_dim = 1 + + yield (torch.from_numpy(x), str(self.data.dates[i])) + + +def worker_init_func(worker_id: int) -> None: + """Configures each dataset worker process. + + Calls WeatherBenchDataset.per_worker_init() on each dataset object. + + Parameters + ---------- + worker_id : int + Worker ID + + Raises + ------ + RuntimeError + If worker_info is None - for idx, x in enumerate(iter(self.dataCls)): - yield (x, str(self.data.dates[self.dataCls.chunk_index_range[idx] + self.dataCls.multi_step -1])) + """ + worker_info = get_worker_info() # information specific to each worker process + if worker_info is None: + LOGGER.error("worker_info is None! 
Set num_workers > 0 in your dataloader!") + raise RuntimeError + dataset_obj = worker_info.dataset # the copy of the dataset held by this worker process. + dataset_obj.per_worker_init( + n_workers=worker_info.num_workers, + worker_id=worker_id, + ) \ No newline at end of file diff --git a/bris/data/grid_indices.py b/bris/data/grid_indices.py new file mode 100644 index 0000000..2bbc90c --- /dev/null +++ b/bris/data/grid_indices.py @@ -0,0 +1,100 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +from __future__ import annotations + +import logging +from abc import ABC +from abc import abstractmethod +from collections.abc import Sequence +from typing import TYPE_CHECKING +from typing import Union + +import numpy as np + +from torch_geometric.data import HeteroData + +LOGGER = logging.getLogger(__name__) + +ArrayIndex = Union[slice, int, Sequence[int]] + + +class BaseGridIndices(ABC): + """Base class for custom grid indices.""" + + def __init__(self, nodes_name: str, reader_group_size: int) -> None: + self.nodes_name = nodes_name + self.reader_group_size = reader_group_size + + def setup(self, graph: HeteroData) -> None: + self.grid_size = self.compute_grid_size(graph) + + def split_seq_in_shards(self, reader_group_rank: int) -> tuple[int, int]: + """Get the indices to split a sequence into equal size shards.""" + grid_shard_size = self.grid_size // self.reader_group_size + grid_start = reader_group_rank * grid_shard_size + if reader_group_rank == self.reader_group_size - 1: + grid_end = self.grid_size + else: + grid_end = (reader_group_rank + 1) * grid_shard_size + + return slice(grid_start, grid_end) + + @property + def supporting_arrays(self) -> dict: + return {} + + @abstractmethod + def compute_grid_size(self, graph: HeteroData) -> int: ... + + @abstractmethod + def get_shard_indices(self, reader_group_rank: int) -> ArrayIndex: ... 
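+
+    # Illustrative example (hypothetical sizes): with grid_size = 40320 and
+    # reader_group_size = 4, split_seq_in_shards(0) returns slice(0, 10080) and
+    # split_seq_in_shards(3) returns slice(30240, 40320); the last reader rank
+    # always takes the remainder when grid_size is not evenly divisible.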
+ + +class FullGrid(BaseGridIndices): + """The full grid is loaded.""" + + def compute_grid_size(self, graph: HeteroData) -> int: + if hasattr(graph[self.nodes_name], "num_nodes"): + return graph[self.nodes_name].num_nodes + elif hasattr(graph[self.nodes_name], "coords"): + return graph[self.nodes_name]["coords"].shape[0] + else: + raise ValueError("Could not compute grid size in graph") + + def get_shard_indices(self, reader_group_rank: int) -> ArrayIndex: + return self.split_seq_in_shards(reader_group_rank) + + +class MaskedGrid(BaseGridIndices): + """Grid is masked based on a node attribute.""" + + def __init__(self, nodes_name: str, reader_group_size: int, node_attribute_name: str): + super().__init__(nodes_name, reader_group_size) + self.node_attribute_name = node_attribute_name + + def setup(self, graph: HeteroData) -> None: + LOGGER.info( + "The graph attribute %s of the %s nodes will be used to masking the spatial dimension.", + self.node_attribute_name, + self.nodes_name, + ) + self.grid_indices = graph[self.nodes_name][self.node_attribute_name].squeeze().tolist() + super().setup(graph) + + @property + def supporting_arrays(self) -> dict: + return {"grid_indices": np.array(self.grid_indices, dtype=np.int64)} + + def compute_grid_size(self, _graph: HeteroData) -> int: + return len(self.grid_indices) + + def get_shard_indices(self, reader_group_rank: int) -> ArrayIndex: + sequence_indices = self.split_seq_in_shards(reader_group_rank) + return self.grid_indices[sequence_indices] \ No newline at end of file diff --git a/bris/ddp_strategy.py b/bris/ddp_strategy.py new file mode 100644 index 0000000..0aa4a47 --- /dev/null +++ b/bris/ddp_strategy.py @@ -0,0 +1,255 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +import logging + +import numpy as np +import pytorch_lightning as pl +import torch +from lightning_fabric.utilities.optimizer import _optimizers_to_device +from pytorch_lightning.overrides.distributed import _sync_module_states +from pytorch_lightning.strategies.ddp import DDPStrategy +from pytorch_lightning.trainer.states import TrainerFn + +from bris.utils import get_base_seed + +LOGGER = logging.getLogger(__name__) + + +class DDPGroupStrategy(DDPStrategy): + """Distributed Data Parallel strategy with group communication.""" + + def __init__(self, num_gpus_per_model: int, read_group_size: int, **kwargs: dict) -> None: + """Initialize the distributed strategy. + + Parameters + ---------- + num_gpus_per_model : int + Number of GPUs per model to shard over. + read_group_size : int + Number of GPUs per reader group. + **kwargs : dict + Additional keyword arguments. + + """ + super().__init__(**kwargs) + self.model_comm_group_size = num_gpus_per_model + self.read_group_size = read_group_size + + def setup(self, trainer: pl.Trainer) -> None: + assert self.accelerator is not None, "Accelerator is not initialized for distributed strategy" + self.accelerator.setup(trainer) + + # determine the model groups that work together: + + assert self.world_size % self.model_comm_group_size == 0, ( + f"Total number of GPUs ({self.world_size}) must be divisible by the number of GPUs " + f"per model ({self.model_comm_group_size})." 
+ ) + + model_comm_group_ranks = np.split( + np.arange(self.world_size, dtype=int), + int(self.world_size / self.model_comm_group_size), + ) + model_comm_groups = [ + torch.distributed.new_group(x) for x in model_comm_group_ranks + ] # every rank has to create all of these + + model_comm_group_id, model_comm_group_rank, model_comm_num_groups = self.get_my_model_comm_group( + self.model_comm_group_size, + ) + model_comm_group = model_comm_groups[model_comm_group_id] + self.model.set_model_comm_group( + model_comm_group, + model_comm_group_id, + model_comm_group_rank, + model_comm_num_groups, + self.model_comm_group_size, + ) + + # set up reader groups by further splitting model_comm_group_ranks with read_group_size: + + assert self.model_comm_group_size % self.read_group_size == 0, ( + f"Number of GPUs per model ({self.model_comm_group_size}) must be divisible by read_group_size " + f"({self.read_group_size})." + ) + + reader_group_ranks = np.array( + [ + np.split(group_ranks, int(self.model_comm_group_size / self.read_group_size)) + for group_ranks in model_comm_group_ranks + ], + ) # Shape: (num_model_comm_groups, model_comm_grp_size/read_group_size, read_group_size) + reader_groups = [[torch.distributed.new_group(x) for x in group_ranks] for group_ranks in reader_group_ranks] + reader_group_id, reader_group_rank, reader_group_size, reader_group_root = self.get_my_reader_group( + model_comm_group_rank, + self.read_group_size, + ) + # get all reader groups of the current model group + model_reader_groups = reader_groups[model_comm_group_id] + self.model.set_reader_groups( + model_reader_groups, + reader_group_id, + reader_group_rank, + reader_group_size, + ) + + LOGGER.debug( + "Rank %d model_comm_group_id: %d model_comm_group: %s model_comm_group_rank: %d " + "reader_group_id: %d reader_group: %s reader_group_rank: %d reader_group_root (global): %d", + self.global_rank, + model_comm_group_id, + str(model_comm_group_ranks[model_comm_group_id]), + model_comm_group_rank, + reader_group_id, + reader_group_ranks[model_comm_group_id, reader_group_id], + reader_group_rank, + reader_group_root, + ) + + # register hooks for correct gradient reduction + self.register_parameter_hooks() + + # move the model to the correct device + self.model_to_device() + + # skip wrapping the model if we are not fitting as no gradients need to be exchanged + trainer_fn = trainer.state.fn + + if trainer_fn == TrainerFn.FITTING and self._layer_sync: + assert self.model is not None, "Model is not initialized for distributed strategy" + self.model = self._layer_sync.apply(self.model) + + self.setup_precision_plugin() + + if trainer_fn == TrainerFn.FITTING: + # do not wrap with DDP if not fitting as there's no gradients to reduce + self.configure_ddp() + + # set up optimizers after the wrapped module has been moved to the device + self.setup_optimizers(trainer) + _optimizers_to_device(self.optimizers, self.root_device) + + import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD + + if isinstance(self._ddp_comm_state, post_localSGD.PostLocalSGDState): + self._enable_model_averaging() + else: + # we need to manually synchronize the module's states since we aren't using the DDP wrapper + assert self.model is not None, "Model is not initialized for distributed strategy" + _sync_module_states(self.model) + + # seed ranks + self.seed_rnd(model_comm_group_id) + + def get_my_model_comm_group(self, num_gpus_per_model: int) -> tuple[int, int, int]: + """Determine tasks that work together and from a 
model group. + + Parameters + ---------- + num_gpus_per_model : int + Number of GPUs per model to shard over. + + Returns + ------- + tuple[int, int, int] + Model_comm_group id, Model_comm_group rank, Number of model_comm_groups + """ + model_comm_group_id = self.global_rank // num_gpus_per_model + model_comm_group_rank = self.global_rank % num_gpus_per_model + model_comm_num_groups = self.world_size // num_gpus_per_model + + return model_comm_group_id, model_comm_group_rank, model_comm_num_groups + + def get_my_reader_group(self, model_comm_group_rank: int, read_group_size: int) -> tuple[int, int, int]: + """Determine tasks that work together and from a reader group. + + Parameters + ---------- + model_comm_group_rank : int + Rank within the model communication group. + read_group_size : int + Number of dataloader readers per model group. + + Returns + ------- + tuple[int, int, int] + Reader_group id, Reader_group rank, Reader_group root (global rank) + """ + reader_group_id = model_comm_group_rank // read_group_size + reader_group_rank = model_comm_group_rank % read_group_size + reader_group_size = read_group_size + reader_group_root = (self.global_rank // read_group_size) * read_group_size + + return reader_group_id, reader_group_rank, reader_group_size, reader_group_root + + def process_dataloader(self, dataloader: torch.utils.data.DataLoader) -> torch.utils.data.DataLoader: + """Pass communication group information to the dataloader for distributed training. + + Parameters + ---------- + dataloader : torch.utils.data.DataLoader + Dataloader to process. + + Returns + ------- + torch.utils.data.DataLoader + Processed dataloader. + + """ + dataloader = super().process_dataloader(dataloader) + + # pass model and reader group information to the dataloaders dataset + model_comm_group_id, model_comm_group_rank, model_comm_num_groups = self.get_my_model_comm_group( + self.model_comm_group_size, + ) + _, reader_group_rank, _, _ = self.get_my_reader_group(model_comm_group_rank, self.read_group_size) + + dataloader.dataset.set_comm_group_info( + self.global_rank, + model_comm_group_id, + model_comm_group_rank, + model_comm_num_groups, + reader_group_rank, + self.read_group_size, + ) + + return dataloader + + def seed_rnd(self, model_comm_group_id: int) -> None: + """Seed the random number generators for the rank.""" + base_seed = get_base_seed() + initial_seed = base_seed * (model_comm_group_id + 1) + rnd_seed = pl.seed_everything(initial_seed) # note: workers are seeded independently in dataloader + np_rng = np.random.default_rng(rnd_seed) + sanity_rnd = (torch.rand(1), np_rng.random()) + LOGGER.debug( + ( + "Strategy: Rank %d, model comm group id %d, base seed %d, seeded with %d, " + "running with random seed: %d, sanity rnd: %s" + ), + self.global_rank, + model_comm_group_id, + base_seed, + initial_seed, + rnd_seed, + sanity_rnd, + ) + + def register_parameter_hooks(self) -> None: + """Register parameter hooks for gradient reduction. + + Here, we rescale parameters that only see a subset of the input on each rank + -> these are still divided by the total number of GPUs in DDP as if each rank would see a full set of inputs + note: the trainable parameters are added before the split across GPUs and are therefore not rescaled. 
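+
+        Worked example (illustrative): with model_comm_group_size = 4, each registered hook
+        multiplies the incoming gradient by 4.0, compensating for DDP having divided by all
+        ranks even though each rank only contributed a quarter of the inputs.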
+ """ + for name, param in self.model.named_parameters(): + if param.requires_grad is True and "trainable" not in name: + param.register_hook(lambda grad: grad * float(self.model_comm_group_size)) \ No newline at end of file diff --git a/bris/inference.py b/bris/inference.py index 9315a62..8d93a0c 100644 --- a/bris/inference.py +++ b/bris/inference.py @@ -9,8 +9,7 @@ from .checkpoint import Checkpoint from .data.datamodule import DataModule -from .utils import check_anemoi_training -from .writer import CustomWriter +from bris.ddp_strategy import DDPGroupStrategy LOGGER = logging.getLogger(__name__) @@ -32,7 +31,6 @@ def __init__( self.checkpoint = checkpoint self.callbacks = callbacks self.datamodule = datamodule - self.deterministic = self.config.deterministic self.precision = precision self._device = device @@ -50,34 +48,14 @@ def device(self) -> str: else: LOGGER.info(f"Using specified device: {self._device}") return self._device - + @cached_property def strategy(self): - if check_anemoi_training(self.checkpoint._metadata): - LOGGER.info("Anemoi training package found, using its strategy") - from anemoi.training.distributed.strategy import DDPGroupStrategy - try: - return DDPGroupStrategy( - self.config.hardware.num_gpus_per_model, - self.config.dataloader.get( - "read_group_size", self.config.hardware.num_gpus_per_model - ), - static_graph=False, #not self.checkpoint.config.training.accum_grad_batches > 1, - ) - except: - return DDPGroupStrategy( - self.config.hardware.num_gpus_per_model, - static_graph=False - ) - else: - LOGGER.info( - "Anemoi training package not found! Using aifs-mono legacy strategy" - ) - from bris.data.legacy.distributed.strategy import DDPGroupStrategy - return DDPGroupStrategy( - self.config.hardware.num_gpus_per_model, - static_graph=False, #not self.checkpoint.config.training.accum_grad_batches > 1, - ) + return DDPGroupStrategy( + num_gpus_per_model=self.config.hardware.num_gpus_per_model, + read_group_size=1, + static_graph=False, + ) @cached_property diff --git a/bris/routes.py b/bris/routes.py index bc680e1..d3c1a14 100644 --- a/bris/routes.py +++ b/bris/routes.py @@ -39,7 +39,7 @@ def get( for config in routing_config: decoder_index = config["decoder_index"] - domain_index = config["domain"] + domain_index = config["domain_index"] curr_grids = data_module.grids[decoder_index] if domain_index == 0: diff --git a/bris/schema/schema.json b/bris/schema/schema.json index 0fff51f..583c58a 100644 --- a/bris/schema/schema.json +++ b/bris/schema/schema.json @@ -126,14 +126,19 @@ "type": "integer", "minimum": 1 }, - "timestep": { - "description": "What is the timestep between each leadtime? E.g. 6h", - "type": "string" - }, "release_cache": { "description": "This option releases unused cached/memory used by torch", "type": "boolean" }, + "workdir": { + "description": "Path to work directory", + "type": "string" + }, + "dataset": { + "description": "Input dataset. Can be a single path, or two datasets merged by cutout. 
Example dataset.cutout = ['/path/dataset1', '/path/dataset2' ]", + "min_distance_km": {"type": "int", "minimum": 0}, + "adjust": {"type": "string"} + }, "routing": { "type": "array", "items": { @@ -154,5 +159,5 @@ } } }, - "required": ["leadtimes", "frequency", "timestep", "release_cache"] + "required": ["leadtimes", "frequency", "release_cache", "dataset", "workdir"] } diff --git a/bris/utils.py b/bris/utils.py index c87f3f8..2db983e 100644 --- a/bris/utils.py +++ b/bris/utils.py @@ -77,7 +77,7 @@ def check_anemoi_dataset_version(metadata) -> tuple[bool, str]: def create_config(parser: ArgumentParser) -> OmegaConf: args, _ = parser.parse_known_args() - validate(args.config) + validate(args.config, raise_on_error=True) try: config = OmegaConf.load(args.config) @@ -88,11 +88,16 @@ def create_config(parser: ArgumentParser) -> OmegaConf: parser.add_argument( "-c", type=str, dest="checkpoint_path", default=config.checkpoint_path ) - parser.add_argument("-sd", type=str, dest="start_date", default=config.start_date) + parser.add_argument("-sd", type=str, dest="start_date", required=False, + default=config.start_date if "start_date" in config else None) parser.add_argument("-ed", type=str, dest="end_date", default=config.end_date) parser.add_argument( "-p", type=str, dest="dataset_path", help="Path to dataset", default=None ) + parser.add_argument( + "-wd", type=str, dest="workdir", help="Path to work directory", required=False, + default=config.workdir if "workdir" in config else None + ) parser.add_argument( "-pc", @@ -106,7 +111,6 @@ def create_config(parser: ArgumentParser) -> OmegaConf: # TODO: Logic that can add dataset or cutout dataset to the dataloader config parser.add_argument("-f", type=str, dest="frequency", default=config.frequency) - parser.add_argument("-s", type=str, dest="timestep", default=config.timestep) parser.add_argument("-l", type=int, dest="leadtimes", default=config.leadtimes) args = parser.parse_args() @@ -133,7 +137,7 @@ def validate(filename, raise_on_error=False): with open(filename) as file: config = yaml.safe_load(file) try: - q = jsonschema.validate(instance=config, schema=schema) + jsonschema.validate(instance=config, schema=schema) except jsonschema.exceptions.ValidationError as e: if raise_on_error: raise @@ -141,8 +145,72 @@ def validate(filename, raise_on_error=False): print("WARNING: Schema does not validate") print(e) - def recursive_list_to_tuple(data): if isinstance(data, list): return tuple(recursive_list_to_tuple(item) for item in data) return data + +def get_usable_indices( + missing_indices: set[int] | None, + series_length: int, + rollout: int, + multistep: int, + timeincrement: int = 1, +) -> np.ndarray: + """Get the usable indices of a series whit missing indices. + + Parameters + ---------- + missing_indices : set[int] + Dataset to be used. + series_length : int + Length of the series. + rollout : int + Number of steps to roll out. + multistep : int + Number of previous indices to include as predictors. + timeincrement : int + Time increment, by default 1. + + Returns + ------- + usable_indices : np.array + Array of usable indices. 
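+
+    Examples
+    --------
+    Illustrative values: one missing date at index 5 in a series of length 10; with
+    multistep=2 and rollout=1 the neighbouring indices and the series edges are also excluded.
+
+    >>> get_usable_indices({5}, 10, rollout=1, multistep=2, timeincrement=1)
+    array([1, 2, 3, 7, 8])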
+ """ + prev_invalid_dates = (multistep - 1) * timeincrement + next_invalid_dates = rollout * timeincrement + + usable_indices = np.arange(series_length) # set of all indices + + if missing_indices is None: + missing_indices = set() + + missing_indices |= {-1, series_length} # to filter initial and final indices + + # Missing indices + for i in missing_indices: + usable_indices = usable_indices[ + (usable_indices < i - next_invalid_dates) + (usable_indices > i + prev_invalid_dates) + ] + + return usable_indices + +def get_base_seed(env_var_list=("AIFS_BASE_SEED", "SLURM_JOB_ID")) -> int: + """Gets the base seed from the environment variables. + + Option to manually set a seed via export AIFS_BASE_SEED=xxx in job script + """ + base_seed = None + for env_var in env_var_list: + if env_var in os.environ: + base_seed = int(os.environ.get(env_var)) + break + + assert ( + base_seed is not None + ), f"Base seed not found in environment variables {env_var_list}" + + if base_seed < 1000: + base_seed = base_seed * 1000 # make it (hopefully) big enough + + return base_seed diff --git a/config.yaml b/config.yaml index f086838..663c584 100644 --- a/config.yaml +++ b/config.yaml @@ -60,14 +60,14 @@ checkpoints: routing: - decoder_index: 0 - domain: 0 + domain_index: 0 outputs: - netcdf: filename_pattern: pred_%Y%m%dT%HZ.nc variables: [2t, 2d] # - decoder_index: 0 -# domain: 1 +# domain_index: 1 # outputs: # - verif: # filename: global/2t/%R.nc # global/2t/legendary_gnome.nc @@ -87,7 +87,7 @@ routing: # path: mslp.nc # # - decoder_index: 1 -# domain: 0 +# domain_index: 0 # outputs: # - netcdf: # filename_pattern: netatmo_%Y%m%dT%HZ.nc diff --git a/config/legacy_legendary_gnome.yaml b/config/legacy_legendary_gnome.yaml index e029f5e..e563b9f 100644 --- a/config/legacy_legendary_gnome.yaml +++ b/config/legacy_legendary_gnome.yaml @@ -42,7 +42,7 @@ dataloader: adjust: all datamodule: - _target_: anemoi.training.data.dataset.NativeGridDataset + _target_: bris.data.dataset.NativeGridDataset _convert_: all hardware: @@ -72,13 +72,13 @@ checkpoints: routing: - decoder_index: 0 - domain: 0 + domain_index: 0 outputs: - netcdf: filename_pattern: meps_pred_%Y%m%dT%HZ.nc variables: [2t, 2d] - decoder_index: 0 - domain: 1 + domain_index: 1 outputs: - netcdf: filename_pattern: era_pred_%Y%m%dT%HZ.nc diff --git a/config/n320.yaml b/config/n320.yaml index 2fef328..311f564 100644 --- a/config/n320.yaml +++ b/config/n320.yaml @@ -37,7 +37,7 @@ dataloader: reorder: ${reorder} datamodule: - _target_: anemoi.training.data.dataset.NativeGridDataset #anemoi.training.data.dataset.ZipDataset + _target_: bris.data.dataset.NativeGridDataset #anemoi.training.data.dataset.ZipDataset _convert_: all hardware: @@ -66,7 +66,7 @@ checkpoints: routing: - decoder_index: 0 - domain: 0 + domain_index: 0 outputs: - netcdf: filename_pattern: /pfs/lustrep4/scratch/project_465001383/haugenha/anemoi-training-ref-updated/run-anemoi/lumi/predictions/n320_pred_%Y%m%dT%HZ.nc diff --git a/config/o96.yaml b/config/o96.yaml index 8f191da..8fb3ea0 100644 --- a/config/o96.yaml +++ b/config/o96.yaml @@ -37,7 +37,7 @@ dataloader: reorder: ${reorder} datamodule: - _target_: anemoi.training.data.dataset.NativeGridDataset #anemoi.training.data.dataset.ZipDataset + _target_: bris.data.dataset.NativeGridDataset #anemoi.training.data.dataset.ZipDataset _convert_: all hardware: @@ -65,7 +65,7 @@ checkpoints: routing: - decoder_index: 0 - domain: 0 + domain_index: 0 outputs: - netcdf: filename_pattern: era_pred_%Y%m%dT%HZ.nc @@ -88,7 +88,7 @@ routing: # path: 
mslp.nc # # - decoder_index: 1 -# domain: 0 +# domain_index: 0 # outputs: # - netcdf: # filename_pattern: netatmo_%Y%m%dT%HZ.nc diff --git a/config/o96_10k_stretched_grid.yaml b/config/o96_10k_stretched_grid.yaml index 9276590..8d88f3b 100644 --- a/config/o96_10k_stretched_grid.yaml +++ b/config/o96_10k_stretched_grid.yaml @@ -53,7 +53,7 @@ dataloader: end: ${end_date} datamodule: - _target_: anemoi.training.data.dataset.NativeGridDataset + _target_: bris.data.dataset.NativeGridDataset _convert_: all hardware: @@ -82,13 +82,13 @@ checkpoints: routing: - decoder_index: 0 - domain: 0 + domain_index: 0 outputs: - netcdf: filename_pattern: fmi_meps_pred_%Y%m%dT%HZ.nc variables: [2t, 2d] - decoder_index: 0 - domain: 1 + domain_index: 1 outputs: - netcdf: filename_pattern: fmi_era_pred_%Y%m%dT%HZ.nc diff --git a/config/o96_metno.yaml b/config/o96_metno.yaml index ceee315..3d19fd4 100644 --- a/config/o96_metno.yaml +++ b/config/o96_metno.yaml @@ -38,7 +38,7 @@ dataloader: reorder: ${reorder} datamodule: - _target_: anemoi.training.data.dataset.NativeGridDataset #anemoi.training.data.dataset.ZipDataset + _target_: bris.data.dataset.NativeGridDataset #anemoi.training.data.dataset.ZipDataset _convert_: all hardware: @@ -58,7 +58,7 @@ model: routing: - decoder_index: 0 - domain: 0 + domain_index: 0 outputs: - netcdf: filename_pattern: output/%R/era_pred_%Y%m%dT%HZ.nc diff --git a/config/stretched-grid.yaml b/config/stretched-grid.yaml index 7457b55..8730bd4 100644 --- a/config/stretched-grid.yaml +++ b/config/stretched-grid.yaml @@ -46,7 +46,7 @@ dataloader: adjust: all datamodule: - _target_: anemoi.training.data.dataset.NativeGridDataset + _target_: bris.data.dataset.NativeGridDataset _convert_: all hardware: @@ -76,13 +76,13 @@ checkpoints: routing: - decoder_index: 0 - domain: 0 + domain_index: 0 outputs: - netcdf: filename_pattern: meps_pred_%Y%m%dT%HZ.nc variables: [2t, 2d] - decoder_index: 0 - domain: 1 + domain_index: 1 outputs: - netcdf: filename_pattern: era_pred_%Y%m%dT%HZ.nc diff --git a/tests/test_routes.py b/tests/test_routes.py index 74c2c67..a1053e4 100644 --- a/tests/test_routes.py +++ b/tests/test_routes.py @@ -1,5 +1,4 @@ import os -import numpy as np import bris.routes @@ -36,7 +35,7 @@ def test_get(): config += [ { "decoder_index": 0, - "domain": 0, + "domain_index": 0, "outputs": [ { "verif": { @@ -52,7 +51,7 @@ def test_get(): }, { "decoder_index": 0, - "domain": 1, + "domain_index": 1, "outputs": [ { "netcdf": { @@ -63,7 +62,7 @@ def test_get(): }, { "decoder_index": 1, - "domain": 0, + "domain_index": 0, "outputs": [ { "netcdf": { @@ -85,7 +84,7 @@ def test_get(): variable_indices = bris.routes.get_variable_indices(config, data_module) assert variable_indices == {0: [0, 1, 2], 1: [1]} - routes = bris.routes.get(config, len(leadtimes), num_members, data_module, workdir) + _ = bris.routes.get(config, len(leadtimes), num_members, data_module, workdir) if __name__ == "__main__":