This repository has been archived by the owner on Feb 3, 2025. It is now read-only.

Merge branch 'develop' into feature/model_freezing
icedoom888 committed Dec 17, 2024
2 parents 6aac548 + 90978df commit a7ab588
Showing 16 changed files with 330 additions and 33 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
@@ -14,6 +14,7 @@ Keep it human-readable, your future self will thank you!
- Don't crash when using the profiler if certain env vars aren't set [#180](https://github.com/ecmwf/anemoi-training/pull/180)
- Remove saving of metadata to training checkpoint [#57](https://github.com/ecmwf/anemoi-training/pull/190)
- Fixes to callback plots [#182] (power spectrum large numpy array error + precip cmap for cases where precip is prognostic).
- Identify stretched grid models based on graph rather than configuration file [#204](https://github.com/ecmwf/anemoi-training/pull/204)

### Added

@@ -24,12 +25,18 @@ Keep it human-readable, your future self will thank you!
- Added a check for the variable sorting on pre-trained/finetuned models [#120](https://github.com/ecmwf/anemoi-training/pull/120)
- <b> Model Freezing ❄️</b>: you can now freeze parts of your model by listing the submodules to freeze in the new config parameter `submodules_to_freeze`.
- New config parameter `submodules_to_freeze` (`List[str]`): list of submodules to freeze.
- Added new metrics for stretched grid models to track losses inside/outside the regional domain [#199](https://github.com/ecmwf/anemoi-training/pull/199)

### Changed

### Removed
- Removed the resolution config entry [#120](https://github.com/ecmwf/anemoi-training/pull/120)

### Added

- Add supporting arrays (numpy) to checkpoint
- Support for masking out unconnected nodes in LAM [#171](https://github.com/ecmwf/anemoi-training/pull/171)

## [0.3.1 - AIFS v0.3 Compatibility](https://github.com/ecmwf/anemoi-training/compare/0.3.0...0.3.1) - 2024-11-28

### Changed
3 changes: 3 additions & 0 deletions README.md
@@ -1,5 +1,8 @@
# anemoi-training

[![Documentation Status](https://readthedocs.org/projects/anemoi-training/badge/?version=latest)](https://anemoi-training.readthedocs.io/en/latest/?badge=latest)


**DISCLAIMER**
This project is **BETA** and will be **Experimental** for the foreseeable future.
Interfaces and functionality are likely to change, and the project itself may be scrapped.
2 changes: 1 addition & 1 deletion docs/conf.py
@@ -42,7 +42,7 @@

author = "Anemoi contributors"

year = datetime.datetime.now(tz="UTC").year
year = datetime.datetime.now(tz=datetime.timezone.utc).year
years = "2024" if year == 2024 else f"2024-{year}"

copyright = f"{years}, Anemoi contributors" # noqa: A001
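The hunk above swaps `tz="UTC"` for `datetime.timezone.utc`, because `datetime.now` expects a `tzinfo` instance rather than a string. A minimal sketch of the difference follows; the helper `copyright_years` is a hypothetical name used only for illustration and does not exist in the repository.

```python
import datetime


def copyright_years(year: int, first_year: int = 2024) -> str:
    """Return '2024' or a '2024-<year>' range, mirroring the logic in docs/conf.py."""
    return str(first_year) if year == first_year else f"{first_year}-{year}"


# A string is not a tzinfo instance, so this call raises TypeError.
try:
    datetime.datetime.now(tz="UTC")
    string_tz_accepted = True
except TypeError:
    string_tz_accepted = False

# The corrected call passes a real tzinfo object.
year = datetime.datetime.now(tz=datetime.timezone.utc).year
years = copyright_years(year)
```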
6 changes: 6 additions & 0 deletions src/anemoi/training/config/dataloader/native_grid.yaml
@@ -40,6 +40,12 @@ limit_batches:
test: 20
predict: 20

# set a custom mask for grid points.
# Useful for LAM (dropping unconnected nodes from forcing dataset)
grid_indices:
_target_: anemoi.training.data.grid_indices.FullGrid
nodes_name: ${graph.data}

# ============
# Dataloader definitions
# These follow the anemoi-datasets patterns
24 changes: 19 additions & 5 deletions src/anemoi/training/config/graph/limited_area.yaml
@@ -10,14 +10,14 @@ nodes:
node_builder:
_target_: anemoi.graphs.nodes.ZarrDatasetNodes
dataset: ${dataloader.training.dataset}
attributes: ${graph.attributes.nodes}
attributes: ${graph.attributes.data_nodes}
# Hidden nodes
hidden:
node_builder:
_target_: anemoi.graphs.nodes.LimitedAreaTriNodes # options: ZarrDatasetNodes, NPZFileNodes, TriNodes
resolution: 5 # grid resolution for npz (o32, o48, ...)
reference_node_name: ${graph.data}
mask_attr_name: cutout
mask_attr_name: cutout_mask

edges:
# Encoder configuration
@@ -26,6 +26,9 @@ edges:
edge_builders:
- _target_: anemoi.graphs.edges.CutOffEdges # options: KNNEdges, CutOffEdges
cutoff_factor: 0.6 # only for cutoff method
- _target_: anemoi.graphs.edges.CutOffEdges # options: KNNEdges, CutOffEdges
cutoff_factor: 2 # only for cutoff method
source_mask_attr_name: boundary_mask
attributes: ${graph.attributes.edges}
# Processor configuration
- source_name: ${graph.hidden}
@@ -39,18 +42,29 @@ edges:
target_name: ${graph.data}
edge_builders:
- _target_: anemoi.graphs.edges.KNNEdges # options: KNNEdges, CutOffEdges
target_mask_attr_name: cutout
target_mask_attr_name: cutout_mask
num_nearest_neighbours: 3 # only for knn method
attributes: ${graph.attributes.edges}


post_processors:
- _target_: anemoi.graphs.processors.RemoveUnconnectedNodes
nodes_name: data
ignore: cutout_mask # optional
save_mask_indices_to_attr: indices_connected_nodes # optional


attributes:
nodes:
data_nodes:
area_weight:
_target_: anemoi.graphs.nodes.attributes.AreaWeights # options: Area, Uniform
norm: unit-max # options: l1, l2, unit-max, unit-sum, unit-std
cutout:
cutout_mask:
_target_: anemoi.graphs.nodes.attributes.CutOutMask
boundary_mask:
_target_: anemoi.graphs.nodes.attributes.BooleanNot
masks:
_target_: anemoi.graphs.nodes.attributes.CutOutMask
edges:
edge_length:
_target_: anemoi.graphs.edges.attributes.EdgeLength
23 changes: 22 additions & 1 deletion src/anemoi/training/data/datamodule.py
@@ -7,15 +7,18 @@
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from __future__ import annotations

import logging
from functools import cached_property
from typing import TYPE_CHECKING
from typing import Callable

import pytorch_lightning as pl
from anemoi.datasets.data import open_dataset
from anemoi.models.data_indices.collection import IndexCollection
from anemoi.utils.dates import frequency_to_seconds
from hydra.utils import instantiate
from omegaconf import DictConfig
from omegaconf import OmegaConf
from torch.utils.data import DataLoader
@@ -25,11 +28,16 @@

LOGGER = logging.getLogger(__name__)

if TYPE_CHECKING:
from torch_geometric.data import HeteroData

from anemoi.training.data.grid_indices import BaseGridIndices


class AnemoiDatasetsDataModule(pl.LightningDataModule):
"""Anemoi Datasets data module for PyTorch Lightning."""

def __init__(self, config: DictConfig) -> None:
def __init__(self, config: DictConfig, graph_data: HeteroData) -> None:
"""Initialize Anemoi Datasets data module.
Parameters
@@ -41,6 +49,7 @@ def __init__(self, config: DictConfig) -> None:
super().__init__()

self.config = config
self.graph_data = graph_data

# Set the maximum rollout to be expected
self.rollout = (
@@ -68,10 +77,21 @@ def statistics(self) -> dict:
def metadata(self) -> dict:
return self.ds_train.metadata

@cached_property
def supporting_arrays(self) -> dict:
return self.ds_train.supporting_arrays | self.grid_indices.supporting_arrays

@cached_property
def data_indices(self) -> IndexCollection:
return IndexCollection(self.config, self.ds_train.name_to_index)

@cached_property
def grid_indices(self) -> BaseGridIndices:
reader_group_size = self.config.dataloader.get("read_group_size", self.config.hardware.num_gpus_per_model)
grid_indices = instantiate(self.config.dataloader.grid_indices, reader_group_size=reader_group_size)
grid_indices.setup(self.graph_data)
return grid_indices

@cached_property
def timeincrement(self) -> int:
"""Determine the step size relative to the data frequency."""
@@ -164,6 +184,7 @@ def _get_dataset(
multistep=self.config.training.multistep_input,
timeincrement=self.timeincrement,
shuffle=shuffle,
grid_indices=self.grid_indices,
label=label,
effective_bs=effective_bs,
)
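The new `supporting_arrays` property in this file merges two dictionaries with the `|` operator, which requires Python 3.9 or later. The sketch below shows how that operator behaves on a key clash; the dictionary names and values are made-up stand-ins, not data from the repository.

```python
# Hypothetical stand-ins for the dataset and grid-indices supporting arrays.
dataset_arrays = {"latitudes": [10.0, 20.0], "longitudes": [0.0, 5.0]}
grid_arrays = {"grid_indices": [0, 2]}

# Dict union returns a new dict and leaves both operands unchanged.
merged = dataset_arrays | grid_arrays

# On a key clash the right-hand operand wins, so in the property above
# entries from grid_indices.supporting_arrays would take precedence.
clash = {"grid_indices": [9]} | grid_arrays
```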
32 changes: 17 additions & 15 deletions src/anemoi/training/data/dataset.py
@@ -13,6 +13,7 @@
import os
import random
from functools import cached_property
from typing import TYPE_CHECKING
from typing import Callable

import numpy as np
@@ -26,13 +27,17 @@

LOGGER = logging.getLogger(__name__)

if TYPE_CHECKING:
from anemoi.training.data.grid_indices import BaseGridIndices


class NativeGridDataset(IterableDataset):
"""Iterable dataset for Anemoi data on arbitrary grids."""

def __init__(
self,
data_reader: Callable,
grid_indices: BaseGridIndices,
rollout: int = 1,
multistep: int = 1,
timeincrement: int = 1,
@@ -46,6 +51,8 @@ def __init__(
----------
data_reader : Callable
user function that opens and returns the zarr array data
grid_indices : BaseGridIndices
provides the spatial indices each reader loads; FullGrid keeps all spatial indices.
rollout : int, optional
length of rollout window, by default 1
timeincrement : int, optional
@@ -66,6 +73,7 @@ def __init__(

self.rollout = rollout
self.timeincrement = timeincrement
self.grid_indices = grid_indices

# lazy init
self.n_samples_per_epoch_total: int = 0
@@ -90,8 +98,6 @@ def __init__(
assert self.multi_step > 0, "Multistep value must be greater than zero."
self.ensemble_dim: int = 2
self.ensemble_size = self.data.shape[self.ensemble_dim]
self.grid_dim: int = -1
self.grid_size = self.data.shape[self.grid_dim]

@cached_property
def statistics(self) -> dict:
@@ -103,6 +109,11 @@ def metadata(self) -> dict:
"""Return dataset metadata."""
return self.data.metadata()

@cached_property
def supporting_arrays(self) -> dict:
"""Return dataset supporting_arrays."""
return self.data.supporting_arrays()

@cached_property
def name_to_index(self) -> dict:
"""Return the mapping from variable names to indices."""
@@ -160,14 +171,7 @@ def set_comm_group_info(
self.reader_group_rank = reader_group_rank
self.reader_group_size = reader_group_size

if self.reader_group_size > 1:
# get the grid shard size and start/end indices
grid_shard_size = self.grid_size // self.reader_group_size
self.grid_start = self.reader_group_rank * grid_shard_size
if self.reader_group_rank == self.reader_group_size - 1:
self.grid_end = self.grid_size
else:
self.grid_end = (self.reader_group_rank + 1) * grid_shard_size
assert self.reader_group_size >= 1, "reader_group_size must be positive"

LOGGER.debug(
"NativeGridDataset.set_group_info(): global_rank %d, model_comm_group_id %d, "
@@ -274,11 +278,9 @@ def __iter__(self) -> torch.Tensor:
start = i - (self.multi_step - 1) * self.timeincrement
end = i + (self.rollout + 1) * self.timeincrement

if self.reader_group_size > 1: # read only a subset of the grid
x = self.data[start : end : self.timeincrement, :, :, self.grid_start : self.grid_end]
else: # read the full grid
x = self.data[start : end : self.timeincrement, :, :, :]

grid_shard_indices = self.grid_indices.get_shard_indices(self.reader_group_rank)
x = self.data[start : end : self.timeincrement, :, :, :]
x = x[..., grid_shard_indices] # select the grid shard
x = rearrange(x, "dates variables ensemble gridpoints -> dates ensemble gridpoints variables")
self.ensemble_dim = 1

96 changes: 96 additions & 0 deletions src/anemoi/training/data/grid_indices.py
@@ -0,0 +1,96 @@
# (C) Copyright 2024 Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from __future__ import annotations

import logging
from abc import ABC
from abc import abstractmethod
from collections.abc import Sequence
from typing import TYPE_CHECKING
from typing import Union

import numpy as np

if TYPE_CHECKING:
    from torch_geometric.data import HeteroData

LOGGER = logging.getLogger(__name__)

ArrayIndex = Union[slice, int, Sequence[int]]


class BaseGridIndices(ABC):
    """Base class for custom grid indices."""

    def __init__(self, nodes_name: str, reader_group_size: int) -> None:
        self.nodes_name = nodes_name
        self.reader_group_size = reader_group_size

    def setup(self, graph: HeteroData) -> None:
        self.grid_size = self.compute_grid_size(graph)

    def split_seq_in_shards(self, reader_group_rank: int) -> slice:
        """Get the slice that selects this rank's shard of the sequence.

        Shards have equal size, except the last one, which also takes the remainder.
        """
        grid_shard_size = self.grid_size // self.reader_group_size
        grid_start = reader_group_rank * grid_shard_size
        if reader_group_rank == self.reader_group_size - 1:
            grid_end = self.grid_size
        else:
            grid_end = (reader_group_rank + 1) * grid_shard_size

        return slice(grid_start, grid_end)

    @property
    def supporting_arrays(self) -> dict:
        return {}

    @abstractmethod
    def compute_grid_size(self, graph: HeteroData) -> int: ...

    @abstractmethod
    def get_shard_indices(self, reader_group_rank: int) -> ArrayIndex: ...


class FullGrid(BaseGridIndices):
    """The full grid is loaded."""

    def compute_grid_size(self, graph: HeteroData) -> int:
        return graph[self.nodes_name].num_nodes

    def get_shard_indices(self, reader_group_rank: int) -> ArrayIndex:
        return self.split_seq_in_shards(reader_group_rank)


class MaskedGrid(BaseGridIndices):
    """Grid is masked based on a node attribute."""

    def __init__(self, nodes_name: str, reader_group_size: int, node_attribute_name: str) -> None:
        super().__init__(nodes_name, reader_group_size)
        self.node_attribute_name = node_attribute_name

    def setup(self, graph: HeteroData) -> None:
        LOGGER.info(
            "The graph attribute %s of the %s nodes will be used to mask the spatial dimension.",
            self.node_attribute_name,
            self.nodes_name,
        )
        self.grid_indices = graph[self.nodes_name][self.node_attribute_name].squeeze().tolist()
        super().setup(graph)

    @property
    def supporting_arrays(self) -> dict:
        return {"grid_indices": np.array(self.grid_indices, dtype=np.int64)}

    def compute_grid_size(self, _graph: HeteroData) -> int:
        return len(self.grid_indices)

    def get_shard_indices(self, reader_group_rank: int) -> ArrayIndex:
        sequence_indices = self.split_seq_in_shards(reader_group_rank)
        return self.grid_indices[sequence_indices]
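
The sharding rule in `split_seq_in_shards` can be tried out without a graph. The sketch below is a standalone approximation, not the repository's code: `shard_slice` is a hypothetical free function that mirrors the method, and the grid sizes and kept indices are made-up values.

```python
def shard_slice(grid_size: int, group_size: int, rank: int) -> slice:
    """Mirror of split_seq_in_shards: equal shards, last rank takes the remainder."""
    shard = grid_size // group_size
    start = rank * shard
    end = grid_size if rank == group_size - 1 else (rank + 1) * shard
    return slice(start, end)


# FullGrid-like case: 10 grid points read by 3 ranks.
full_shards = [shard_slice(10, 3, rank) for rank in range(3)]

# MaskedGrid-like case: only some grid points are kept, then sharded over 2 ranks.
kept_indices = [1, 4, 5, 8, 9]
masked_shards = [kept_indices[shard_slice(len(kept_indices), 2, rank)] for rank in range(2)]

print(full_shards)    # slices covering 0-3, 3-6 and 6-10
print(masked_shards)  # [[1, 4], [5, 8, 9]]
```

Because the last rank absorbs the remainder, shard sizes can differ by up to `group_size - 1` points when the grid size is not divisible by the reader group size.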