
Commit

Avoid circular imports, cleanup interpolator
OpheliaMiralles committed Jan 6, 2025
1 parent a09eb4b commit 3f52e40
Showing 3 changed files with 33 additions and 125 deletions.
131 changes: 29 additions & 102 deletions src/anemoi/training/data/dataset.py
@@ -13,7 +13,6 @@
import os
import random
from functools import cached_property
from typing import TYPE_CHECKING
from typing import Callable

import numpy as np
@@ -27,42 +26,26 @@

LOGGER = logging.getLogger(__name__)

if TYPE_CHECKING:
from anemoi.training.data.grid_indices import BaseGridIndices


class NativeGridDataset(IterableDataset):
"""Iterable dataset for AnemoI data on the arbitrary grids."""

def __init__(
self,
data_reader: Callable,
grid_indices: type[BaseGridIndices],
rollout: int = 1,
multistep: int = 1,
timeincrement: int = 1,
relative_date_indices: list = [0, 1, 2],
relative_date_indices: list = [0,1,2],
model_comm_group_rank: int = 0,
model_comm_group_id: int = 0,
model_comm_num_groups: int = 1,
shuffle: bool = True,
label: str = "generic",
effective_bs: int = 1,
) -> None:
"""Initialize (part of) the dataset state.
Parameters
----------
data_reader : Callable
user function that opens and returns the zarr array data
grid_indices : Type[BaseGridIndices]
indices of the grid to keep. Defaults to None, which keeps all spatial indices.
rollout : int, optional
length of rollout window, by default 1
timeincrement : int, optional
time increment between samples, by default 1
multistep : int, optional
collate (t-1, ... t - multistep) into the input state vector, by default 1
relative_date_indices : list
list of time indices to load from the data relative to the current sample i in __iter__
model_comm_group_rank : int, optional
@@ -75,30 +58,21 @@ def __init__(
Shuffle batches, by default True
label : str, optional
label for the dataset, by default "generic"
effective_bs : int, default 1
effective batch size useful to compute the length of the dataset
"""
self.label = label
self.effective_bs = effective_bs

self.data = data_reader

self.rollout = rollout
self.timeincrement = timeincrement
self.grid_indices = grid_indices

# lazy init
self.n_samples_per_epoch_total: int = 0
self.n_samples_per_epoch_per_worker: int = 0

# lazy init model and reader group info, will be set by the DDPGroupStrategy:
self.model_comm_group_rank = 0
self.model_comm_num_groups = 1
self.model_comm_group_id = 0
self.global_rank = 0

self.reader_group_rank = 0
self.reader_group_size = 1
# DDP-relevant info
self.model_comm_group_rank = model_comm_group_rank
self.model_comm_num_groups = model_comm_num_groups
self.model_comm_group_id = model_comm_group_id
self.global_rank = int(os.environ.get("SLURM_PROCID", "0"))

# additional state vars (lazy init)
self.n_samples_per_worker = 0
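
For context on the new `relative_date_indices` argument documented above: each sample `i` drawn in `__iter__` loads the dates `i + offset` for every offset in the list. A minimal, self-contained sketch (plain NumPy, with a hypothetical index `i`; not part of this commit):

```python
import numpy as np

relative_date_indices = [0, 1, 2]   # offsets relative to the sample index
i = 40                              # hypothetical valid start index

# The dates fetched for sample i are simply i + each offset.
dates = np.array(relative_date_indices, dtype=np.int64) + i
print(dates)  # [40 41 42]
```
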
@@ -122,11 +96,6 @@ def metadata(self) -> dict:
"""Return dataset metadata."""
return self.data.metadata()

@cached_property
def supporting_arrays(self) -> dict:
"""Return dataset supporting_arrays."""
return self.data.supporting_arrays()

@cached_property
def name_to_index(self) -> dict:
"""Return dataset statistics."""
@@ -149,57 +118,7 @@ def valid_date_indices(self) -> np.ndarray:
dataset length minus rollout minus additional multistep inputs
(if time_increment is 1).
"""
return get_usable_indices(
self.data.missing,
len(self.data),
np.array(self.relative_date_indices, dtype=np.int64),
self.data.model_run_ids,
)

def set_comm_group_info(
self,
global_rank: int,
model_comm_group_id: int,
model_comm_group_rank: int,
model_comm_num_groups: int,
reader_group_rank: int,
reader_group_size: int,
) -> None:
"""Set model and reader communication group information (called by DDPGroupStrategy).
Parameters
----------
global_rank : int
Global rank
model_comm_group_id : int
Model communication group ID
model_comm_group_rank : int
Model communication group rank
model_comm_num_groups : int
Number of model communication groups
reader_group_rank : int
Reader group rank
reader_group_size : int
Reader group size
"""
self.global_rank = global_rank
self.model_comm_group_id = model_comm_group_id
self.model_comm_group_rank = model_comm_group_rank
self.model_comm_num_groups = model_comm_num_groups
self.reader_group_rank = reader_group_rank
self.reader_group_size = reader_group_size

assert self.reader_group_size >= 1, "reader_group_size must be positive"

LOGGER.debug(
"NativeGridDataset.set_group_info(): global_rank %d, model_comm_group_id %d, "
"model_comm_group_rank %d, model_comm_num_groups %d, reader_group_rank %d",
global_rank,
model_comm_group_id,
model_comm_group_rank,
model_comm_num_groups,
reader_group_rank,
)
return get_usable_indices(self.data.missing, len(self.data), np.array(self.relative_date_indices, dtype=np.int64), self.data.model_run_ids)
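
To make the contract of `get_usable_indices` concrete, here is a simplified sketch (an illustration only, with a hypothetical helper name; the real function also accounts for model run boundaries via `model_run_ids`, which this sketch omits):

```python
import numpy as np

def usable_indices_sketch(missing: set, n_dates: int,
                          relative_indices: np.ndarray) -> np.ndarray:
    """Start indices i such that every i + offset is in range and not missing."""
    valid = []
    for i in range(n_dates):
        needed = relative_indices + i
        if needed.max() < n_dates and not any(int(t) in missing for t in needed):
            valid.append(i)
    return np.array(valid, dtype=np.int64)

# Example: 10 dates, date 4 missing, offsets [0, 1, 2].
print(usable_indices_sketch({4}, 10, np.array([0, 1, 2])))  # [0 1 5 6 7]
```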

def per_worker_init(self, n_workers: int, worker_id: int) -> None:
"""Called by worker_init_func on each copy of dataset.
@@ -226,7 +145,6 @@ def per_worker_init(self, n_workers: int, worker_id: int) -> None:

low = shard_start + worker_id * self.n_samples_per_worker
high = min(shard_start + (worker_id + 1) * self.n_samples_per_worker, shard_end)
self.chunk_index_range = np.arange(low, high, dtype=np.uint32)

LOGGER.debug(
"Worker %d (pid %d, global_rank %d, model comm group %d) has low/high range %d / %d",
@@ -238,17 +156,27 @@ def per_worker_init(self, n_workers: int, worker_id: int) -> None:
high,
)

self.chunk_index_range = self.valid_date_indices[np.arange(low, high, dtype=np.uint32)]

# each worker must have a different seed for its random number generator,
# otherwise all the workers will output exactly the same data
# should we check lightning env variable "PL_SEED_WORKERS" here?
# but we always want to seed these anyway ...

base_seed = get_base_seed()

torch.manual_seed(base_seed)
random.seed(base_seed)
self.rng = np.random.default_rng(seed=base_seed)
seed = (
base_seed * (self.model_comm_group_id + 1) - worker_id
) # note that test, validation etc. datasets get same seed
torch.manual_seed(seed)
random.seed(seed)
self.rng = np.random.default_rng(seed=seed)
sanity_rnd = self.rng.random(1)

LOGGER.debug(
(
"Worker %d (%s, pid %d, glob. rank %d, model comm group %d, "
"group_rank %d, base_seed %d), sanity rnd %f"
"group_rank %d, base_seed %d) using seed %d, sanity rnd %f"
),
worker_id,
self.label,
Expand All @@ -257,6 +185,7 @@ def per_worker_init(self, n_workers: int, worker_id: int) -> None:
self.model_comm_group_id,
self.model_comm_group_rank,
base_seed,
seed,
sanity_rnd,
)
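
The seed arithmetic above gives every dataloader worker within a model communication group its own random stream while keeping all ranks of that group in sync (the seed depends only on the group id and the worker id, not on the rank). A small sketch with hypothetical values:

```python
base_seed = 1234          # from get_base_seed(); shared across processes
model_comm_group_id = 0   # hypothetical group id

for worker_id in range(3):
    seed = base_seed * (model_comm_group_id + 1) - worker_id
    print(worker_id, seed)  # 0 1234, 1 1233, 2 1232 -> distinct per worker
```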

Expand All @@ -271,12 +200,12 @@ def __iter__(self) -> torch.Tensor:
"""
if self.shuffle:
shuffled_chunk_indices = self.rng.choice(
self.valid_date_indices,
size=len(self.valid_date_indices),
self.chunk_index_range,
size=self.n_samples_per_worker,
replace=False,
)[self.chunk_index_range]
)
else:
shuffled_chunk_indices = self.valid_date_indices[self.chunk_index_range]
shuffled_chunk_indices = self.chunk_index_range
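
The shuffle rewrite is easiest to see in isolation: the old code permuted the full pool of valid dates and then sliced it with this worker's positional range, whereas the new code permutes only the worker's own chunk (which, per the `per_worker_init` change above, now already contains date values rather than positions). A NumPy sketch of the new behaviour, with hypothetical values:

```python
import numpy as np

rng = np.random.default_rng(seed=0)
chunk_index_range = np.array([40, 41, 42, 43])  # this worker's valid dates
n_samples_per_worker = len(chunk_index_range)

shuffled = rng.choice(chunk_index_range, size=n_samples_per_worker, replace=False)
print(shuffled)  # a permutation of [40 41 42 43], local to this worker
```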

LOGGER.debug(
(
@@ -293,9 +222,7 @@
)

for i in shuffled_chunk_indices:
grid_shard_indices = self.grid_indices.get_shard_indices(self.reader_group_rank)
x = x[..., grid_shard_indices] # select the grid shard
x = self.data[self.relative_date_indices + i] # NOTE: this requires an update to anemoi datasets
x = self.data[self.relative_date_indices + i] #NOTE: this requires an update to anemoi datasets
x = rearrange(x, "dates variables ensemble gridpoints -> dates ensemble gridpoints variables")
self.ensemble_dim = 1
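
The fetch-and-rearrange step can be reproduced with a dummy array standing in for `self.data` (small hypothetical dimensions; as the NOTE says, list-of-dates indexing on the real reader requires an anemoi-datasets update):

```python
import numpy as np
from einops import rearrange

# Dummy stand-in for self.data: (dates, variables, ensemble, gridpoints).
data = np.zeros((100, 8, 1, 40), dtype=np.float32)
relative_date_indices = np.array([0, 1, 2])
i = 40  # hypothetical sample index

x = data[relative_date_indices + i]  # -> (3, 8, 1, 40)
x = rearrange(x, "dates variables ensemble gridpoints -> dates ensemble gridpoints variables")
print(x.shape)  # (3, 1, 40, 8)
```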

@@ -333,4 +260,4 @@ def worker_init_func(worker_id: int) -> None:
dataset_obj.per_worker_init(
n_workers=worker_info.num_workers,
worker_id=worker_id,
)
)
23 changes: 2 additions & 21 deletions src/anemoi/training/train/interpolator.py
@@ -9,35 +9,15 @@


import logging
import math
import os
from collections import defaultdict
from collections.abc import Generator
from collections.abc import Mapping
from typing import Optional
from typing import Union
from operator import itemgetter

import numpy as np
import pytorch_lightning as pl
import torch
from anemoi.models.data_indices.collection import IndexCollection
from anemoi.models.interface import AnemoiModelInterface
from anemoi.utils.config import DotDict
from hydra.utils import instantiate
from omegaconf import DictConfig
from omegaconf import OmegaConf
from timm.scheduler import CosineLRScheduler
from torch.distributed.distributed_c10d import ProcessGroup
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.utils.checkpoint import checkpoint
from torch_geometric.data import HeteroData

from anemoi.training.losses.utils import grad_scaler
from anemoi.training.losses.weightedloss import BaseWeightedLoss
from anemoi.training.utils.jsonify import map_config_to_primitives
from anemoi.training.utils.masks import Boolean1DMask
from anemoi.training.utils.masks import NoOutputMask

from anemoi.training.train.forecaster import GraphForecaster

@@ -54,6 +34,7 @@ def __init__(
statistics: dict,
data_indices: IndexCollection,
metadata: dict,
supporting_arrays: dict
) -> None:
"""Initialize graph neural network interpolator.
@@ -71,7 +52,7 @@ def __init__(
Provenance information
"""
super().__init__(config = config, graph_data = graph_data, statistics = statistics, data_indices = data_indices, metadata = metadata)
super().__init__(config = config, graph_data = graph_data, statistics = statistics, data_indices = data_indices, metadata = metadata, supporting_arrays=supporting_arrays)
self.target_forcing_indices = itemgetter(*config.training.target_forcing.data)(data_indices.data.input.name_to_index)
if type(self.target_forcing_indices) == int:
self.target_forcing_indices = [self.target_forcing_indices]
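
The scalar check after `itemgetter` is needed because `itemgetter` with a single key returns a bare value, not a tuple. A short illustration with a hypothetical mapping:

```python
from operator import itemgetter

name_to_index = {"t2m": 0, "u10": 1, "v10": 2}  # hypothetical

print(itemgetter("u10")(name_to_index))         # 1      -> bare int
print(itemgetter("u10", "v10")(name_to_index))  # (1, 2) -> tuple
```

Hence the normalisation to a list when only one target forcing is configured.
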
4 changes: 2 additions & 2 deletions src/anemoi/training/utils/checkpoint.py
@@ -16,8 +16,6 @@
import torch.nn as nn
from anemoi.utils.checkpoints import save_metadata

from anemoi.training.train.forecaster import GraphForecaster

LOGGER = logging.getLogger(__name__)


@@ -35,6 +33,8 @@ def load_and_prepare_model(lightning_checkpoint_path: str) -> tuple[torch.nn.Mod
pytorch model, metadata
"""
from anemoi.training.train.forecaster import GraphForecaster

module = GraphForecaster.load_from_checkpoint(lightning_checkpoint_path)
model = module.model
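
Moving the `GraphForecaster` import into the function body is the standard way to break an import cycle: the name is only resolved at call time, after both modules have finished initialising. A generic sketch of the pattern, with hypothetical module names:

```python
# module_a.py
def build_model(checkpoint_path: str):
    # Imported lazily: module_b imports module_a at its top level, so a
    # top-level import here would complete the circle and fail at startup.
    from module_b import Forecaster
    return Forecaster.load_from_checkpoint(checkpoint_path)
```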

