From dd9883fbfe253ab3ed26c75f5ebe080bbcb82e6f Mon Sep 17 00:00:00 2001 From: HCookie <48088699+HCookie@users.noreply.github.com> Date: Thu, 14 Nov 2024 17:56:31 +0000 Subject: [PATCH 01/16] [changelog update] Update to 0.3.0 --- CHANGELOG.md | 34 +++++++++++++++++++++++----------- 1 file changed, 23 insertions(+), 11 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 21ef6ff3..39fd4471 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,22 +8,36 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 Please add your functional changes to the appropriate section in the PR. Keep it human-readable, your future self will thank you! -## [Unreleased](https://github.com/ecmwf/anemoi-training/compare/0.2.2...HEAD) +## [Unreleased](https://github.com/ecmwf/anemoi-training/compare/0.3.0...HEAD) +### Fixed + +### Added + +### Changed +## [0.3.0 - Loss & Callback Refactors](https://github.com/ecmwf/anemoi-training/compare/0.2.2...0.3.0) - 2024-11-14 ### Fixed + - Rename loss_scaling to variable_loss_scaling [#138](https://github.com/ecmwf/anemoi-training/pull/138) - Refactored callbacks. [#60](https://github.com/ecmwf/anemoi-training/pulls/60) - - Updated docs [#115](https://github.com/ecmwf/anemoi-training/pull/115) - - Fix enabling LearningRateMonitor [#119](https://github.com/ecmwf/anemoi-training/pull/119) + - Updated docs [#115](https://github.com/ecmwf/anemoi-training/pull/115) + - Fix enabling LearningRateMonitor [#119](https://github.com/ecmwf/anemoi-training/pull/119) + - Refactored rollout [#87](https://github.com/ecmwf/anemoi-training/pulls/87) - - Enable longer validation rollout than training + - Enable longer validation rollout than training + - Expand iterables in logging [#91](https://github.com/ecmwf/anemoi-training/pull/91) - - Save entire config in mlflow + - Save entire config in mlflow + + ### Added + - Included more loss functions and allowed configuration [#70](https://github.com/ecmwf/anemoi-training/pull/70) - - Fix that applies the metric_ranges in the post-processed variable space [#116](https://github.com/ecmwf/anemoi-training/pull/116) + - Fix that applies the metric_ranges in the post-processed variable space [#116](https://github.com/ecmwf/anemoi-training/pull/116) + - Allow updates to scalars [#137](https://github.com/ecmwf/anemoi-training/pulls/137) - - Add without subsetting in ScaleTensor + - Add without subsetting in ScaleTensor + - Sub-hour datasets [#63](https://github.com/ecmwf/anemoi-training/pull/63) - Add synchronisation workflow [#92](https://github.com/ecmwf/anemoi-training/pull/92) - Feat: Anemoi Profiler compatible with mlflow and using Pytorch (Kineto) Profiler for memory report [38](https://github.com/ecmwf/anemoi-training/pull/38/) @@ -31,19 +45,17 @@ Keep it human-readable, your future self will thank you! - New stretched grid config added, stretched_grid.yaml [#133](https://github.com/ecmwf/anemoi-training/pull/133) ### Changed + - Renamed frequency keys in callbacks configuration. [#118](https://github.com/ecmwf/anemoi-training/pull/118) - Modified training configuration to support max_steps and tied lr iterations to max_steps by default [#67](https://github.com/ecmwf/anemoi-training/pull/67) - Merged node & edge trainable feature callbacks into one. [#135](https://github.com/ecmwf/anemoi-training/pull/135) ## [0.2.2 - Maintenance: pin python <3.13](https://github.com/ecmwf/anemoi-training/compare/0.2.1...0.2.2) - 2024-10-28 - ### Changed - Lock python version <3.13 [#107](https://github.com/ecmwf/anemoi-training/pull/107) - - ## [0.2.1 - Bugfix: resuming mlflow runs](https://github.com/ecmwf/anemoi-training/compare/0.2.0...0.2.1) - 2024-10-24 ### Added @@ -85,6 +97,7 @@ Keep it human-readable, your future self will thank you! - Variable Bounding as configurable model layers [#13](https://github.com/ecmwf/anemoi-models/issues/13) + #### Functionality - Enable the callback for plotting a histogram for variables containing NaNs @@ -96,7 +109,6 @@ Keep it human-readable, your future self will thank you! - Feature: `AnemoiMlflowClient`, an mlflow client with authentication support [#86](https://github.com/ecmwf/anemoi-training/pull/86) - Long Rollout Plots - ### Fixed - Fix `TypeError` raised when trying to JSON serialise `datetime.timedelta` object - [#43](https://github.com/ecmwf/anemoi-training/pull/43) From d0a8866b0d4dfb31665742978858f77c64e16a9a Mon Sep 17 00:00:00 2001 From: Mariana Clare <31656450+mc4117@users.noreply.github.com> Date: Fri, 15 Nov 2024 08:53:09 +0000 Subject: [PATCH 02/16] Fix/async callbacks (#102) * Refactor Callbacks - Split into seperate files - Use list in config to add callbacks - Provide legacy config enabled approach - Fix ruff issues * Update changelog * Fix TypeError * Move to hydra.instantiate * Add __all__ * Add to base config * Fix nested list * Fix nested get issue * Fix type checking * feat: edge plot in callbacks * feat: set default extra callbacks * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix: typing & refactoring * fix: remove list comprehension * Refactor according to PR - Prefill config with callbacks - Warn on deprecations for old config - Expand config enabled - Add back SWA - Fix logging callback - Add flag to disable checkpointing - Add testing * Update deprecation warning * Refactor: Remove backwards compatability, - Split plots - Rename, lr to optimiser - Refactor plotting callbacks to be more init config * Fix tests * PR Fixes - Remove enabled from plotting callbacks - Connect sample_idx in config * Update Changelog * Refactor rollout (#87) Refactor rollout logic * Remove batch frequency from LongRolloutPlots * Remove TP reference * Remove missing config reference * Authentication support for mlflow sync (#51) * feat: authentication support for mlflow sync * chore: formatting * chore: changelog * chore: changelog add link * fix: sync authentication flag * refactor: move `health_check` to submodule top level * feat: add health check * chore: update error msg * refactor: mlflow utils * New mlflow authentication API (#78) * fix: mlflow auth use web seed token * feat: make target env var an optional argument * chore: docstrings * fix: tests * chore: add comment * chore: changelog * chore: docstring * Update changelog * rebase * Update deprecation warning * Refactor: Remove backwards compatability, - Split plots - Rename, lr to optimiser - Refactor plotting callbacks to be more init config * add scatter plot * adding async * fix * tests * fix failing tests * rm change to ds valid * precommit hooks * fix linting * rebase * Update deprecation warning * Refactor: Remove backwards compatability, - Split plots - Rename, lr to optimiser - Refactor plotting callbacks to be more init config * add scatter plot * adding async * fix * tests * fix failing tests * rm change to ds valid * precommit hooks * fix linting * revert unnecessary config changes * change config files * Swapped histogram and spectrum * Update copyright notice * Fix issues with split of PlotAdditionalMetrics * Fix CHANGELOG * Fix documentation for callbacks * Add all callback submodules to docs * Apply suggestions from code review Co-authored-by: Sara Hahner <44293258+sahahner@users.noreply.github.com> * Fix init args issue in RolloutPlots * Add rollout_eval config * Add training mode to rollout step * Force LongRolloutPlots to plot in serial * Add warning to LongRolloutPlots when async * Fix asserrt calculation * Apply post_processors before plotting in LongRolloutPlots * Fix reference to batch * Fix debug config * brinding plot for mean wave direction and fixing type hinting * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add changelog entry * fixes for async plots to work * fix pre-commit styling * improved loop closing and readability * fixing for pre-commit hooks * remove commented block * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * address sugestion for args and kwargs and missing type hints * update flag to datashader rather than scatter * update configs * update docs * update comment for readability * update branch * update branch and test --------- Co-authored-by: Harrison Cook Co-authored-by: Mario Santa Cruz Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Gert Mertes <13658335+gmertes@users.noreply.github.com> Co-authored-by: Sara Hahner <44293258+sahahner@users.noreply.github.com> Co-authored-by: anaprietonem Co-authored-by: Ana Prieto Nemesio <91897203+anaprietonem@users.noreply.github.com> --- CHANGELOG.md | 1 + docs/modules/diagnostics.rst | 22 +- pyproject.toml | 1 + .../config/diagnostics/plot/detailed.yaml | 1 + .../config/diagnostics/plot/simple.yaml | 1 + src/anemoi/training/data/datamodule.py | 3 +- .../training/diagnostics/callbacks/plot.py | 153 +++++++------ src/anemoi/training/diagnostics/maps.py | 2 +- src/anemoi/training/diagnostics/plots.py | 210 +++++++++++++----- 9 files changed, 259 insertions(+), 135 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 21ef6ff3..287d76ea 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,7 @@ Keep it human-readable, your future self will thank you! - Save entire config in mlflow ### Added - Included more loss functions and allowed configuration [#70](https://github.com/ecmwf/anemoi-training/pull/70) +- Include option to use datashader and optimised asyncronohous callbacks [#102](https://github.com/ecmwf/anemoi-training/pull/102) - Fix that applies the metric_ranges in the post-processed variable space [#116](https://github.com/ecmwf/anemoi-training/pull/116) - Allow updates to scalars [#137](https://github.com/ecmwf/anemoi-training/pulls/137) - Add without subsetting in ScaleTensor diff --git a/docs/modules/diagnostics.rst b/docs/modules/diagnostics.rst index 28eac7c7..4364e683 100644 --- a/docs/modules/diagnostics.rst +++ b/docs/modules/diagnostics.rst @@ -51,12 +51,32 @@ parameters to plot, as well as the plotting frequency, and asynchronosity. Setting ``config.diagnostics.plot.asynchronous``, means that the model -training doesn't stop whilst the callbacks are being evaluated) +training doesn't stop whilst the callbacks are being evaluated. This is +useful for large models where the plotting can take a long time. The +plotting module uses asynchronous callbacks via `asyncio` and +`concurrent.futures.ThreadPoolExecutor` to handle plotting tasks without +blocking the main application. A dedicated event loop runs in a separate +background thread, allowing plotting tasks to be offloaded to worker +threads. This setup keeps the main thread responsive, handling +plot-related tasks asynchronously and efficiently in the background. + +There is an additional flag in the plotting callbacks to control the +rendering method for geospatial plots, offering a trade-off between +performance and detail. When `datashader` is set to True, Datashader is +used for rendering, which accelerates plotting through efficient +hexbining, particularly useful for large datasets. This approach can +produce smoother-looking plots due to the aggregation of data points. If +`datashader` is set to False, matplotlib.scatter is used, which provides +sharper and more detailed visuals but may be slower for large datasets. + +**Note** - this asynchronous behaviour is only available for the +plotting callbacks. .. code:: yaml plot: asynchronous: True # Whether to plot asynchronously + datashader: True # Whether to use datashader for plotting (faster) frequency: # Frequency of the plotting batch: 750 epoch: 5 diff --git a/pyproject.toml b/pyproject.toml index f3e730d3..8d685acd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,7 @@ dependencies = [ "anemoi-graphs>=0.4", "anemoi-models>=0.3", "anemoi-utils[provenance]>=0.4.4", + "datashader>=0.16.3", "einops>=0.6.1", "hydra-core>=1.3", "matplotlib>=3.7.1", diff --git a/src/anemoi/training/config/diagnostics/plot/detailed.yaml b/src/anemoi/training/config/diagnostics/plot/detailed.yaml index b759c17b..d1ac8b0f 100644 --- a/src/anemoi/training/config/diagnostics/plot/detailed.yaml +++ b/src/anemoi/training/config/diagnostics/plot/detailed.yaml @@ -1,4 +1,5 @@ asynchronous: True # Whether to plot asynchronously +datashader: True # Choose which technique to use for plotting frequency: # Frequency of the plotting batch: 750 epoch: 5 diff --git a/src/anemoi/training/config/diagnostics/plot/simple.yaml b/src/anemoi/training/config/diagnostics/plot/simple.yaml index 2a987ccb..63c805a2 100644 --- a/src/anemoi/training/config/diagnostics/plot/simple.yaml +++ b/src/anemoi/training/config/diagnostics/plot/simple.yaml @@ -1,4 +1,5 @@ asynchronous: True # Whether to plot asynchronously +datashader: True # Choose which technique to use for plotting frequency: # Frequency of the plotting batch: 750 epoch: 10 diff --git a/src/anemoi/training/data/datamodule.py b/src/anemoi/training/data/datamodule.py index 0d3d1b3f..303266fc 100644 --- a/src/anemoi/training/data/datamodule.py +++ b/src/anemoi/training/data/datamodule.py @@ -140,8 +140,7 @@ def ds_train(self) -> NativeGridDataset: @cached_property def ds_valid(self) -> NativeGridDataset: - r = self.rollout - r = max(r, self.config.dataloader.get("validation_rollout", 1)) + r = max(self.rollout, self.config.dataloader.get("validation_rollout", 1)) assert self.config.dataloader.training.end < self.config.dataloader.validation.start, ( f"Training end date {self.config.dataloader.training.end} is not before" diff --git a/src/anemoi/training/diagnostics/callbacks/plot.py b/src/anemoi/training/diagnostics/callbacks/plot.py index 869a69fb..171eb840 100644 --- a/src/anemoi/training/diagnostics/callbacks/plot.py +++ b/src/anemoi/training/diagnostics/callbacks/plot.py @@ -7,13 +7,13 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. -# ruff: noqa: ANN001 from __future__ import annotations +import asyncio import copy import logging -import sys +import threading import time import traceback from abc import ABC @@ -23,8 +23,6 @@ from functools import cached_property from pathlib import Path from typing import TYPE_CHECKING -from typing import Any -from typing import Callable import matplotlib.patches as mpatches import matplotlib.pyplot as plt @@ -43,33 +41,14 @@ from anemoi.training.losses.weightedloss import BaseWeightedLoss if TYPE_CHECKING: + from typing import Any + import pytorch_lightning as pl from omegaconf import OmegaConf LOGGER = logging.getLogger(__name__) -class ParallelExecutor(ThreadPoolExecutor): - """Wraps parallel execution and provides accurate information about errors. - - Extends ThreadPoolExecutor to preserve the original traceback and line number. - - Reference: https://stackoverflow.com/questions/19309514/getting-original-line- - number-for-exception-in-concurrent-futures/24457608#24457608 - """ - - def submit(self, fn: Any, *args, **kwargs) -> Callable: - """Submits the wrapped function instead of `fn`.""" - return super().submit(self._function_wrapper, fn, *args, **kwargs) - - def _function_wrapper(self, fn: Any, *args: list, **kwargs: dict) -> Callable: - """Wraps `fn` in order to preserve the traceback of any kind of.""" - try: - return fn(*args, **kwargs) - except Exception as exc: - raise sys.exc_info()[0](traceback.format_exc()) from exc - - class BasePlotCallback(Callback, ABC): """Factory for creating a callback that plots data to Experiment Logging.""" @@ -93,11 +72,21 @@ def __init__(self, config: OmegaConf) -> None: self.plot = self._plot self._executor = None + self._error: BaseException = None + self.datashader_plotting = config.diagnostics.plot.datashader if self.config.diagnostics.plot.asynchronous: - self._executor = ParallelExecutor(max_workers=1) - self._error: BaseException | None = None + LOGGER.info("Setting up asynchronous plotting ...") self.plot = self._async_plot + self._executor = ThreadPoolExecutor(max_workers=1) + self.loop_thread = threading.Thread(target=self.start_event_loop, daemon=True) + self.loop_thread.start() + + def start_event_loop(self) -> None: + """Start the event loop in a separate thread.""" + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + self.loop.run_forever() @rank_zero_only def _output_figure( @@ -113,27 +102,48 @@ def _output_figure( save_path = Path( self.save_basedir, "plots", - f"{tag}_epoch{epoch:03d}.png", + f"{tag}_epoch{epoch:03d}.jpg", ) save_path.parent.mkdir(parents=True, exist_ok=True) - fig.savefig(save_path, dpi=100, bbox_inches="tight") + fig.canvas.draw() + image_array = np.array(fig.canvas.renderer.buffer_rgba()) + plt.imsave(save_path, image_array, dpi=100) if self.config.diagnostics.log.wandb.enabled: import wandb logger.experiment.log({exp_log_tag: wandb.Image(fig)}) - if self.config.diagnostics.log.mlflow.enabled: run_id = logger.run_id logger.experiment.log_artifact(run_id, str(save_path)) plt.close(fig) # cleanup + @rank_zero_only + def _plot_with_error_catching(self, trainer: pl.Trainer, args: Any, kwargs: Any) -> None: + """To execute the plot function but ensuring we catch any errors.""" + try: + self._plot(trainer, *args, **kwargs) + except BaseException: + import os + + LOGGER.exception(traceback.format_exc()) + os._exit(1) # to force exit when sanity val steps are used + def teardown(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str) -> None: - """Method is called to close the threads.""" + """Teardown the callback.""" del trainer, pl_module, stage # unused + LOGGER.info("Teardown of the Plot Callback ...") + if self._executor is not None: - self._executor.shutdown(wait=True) + LOGGER.info("waiting and shutting down the executor ...") + self._executor.shutdown(wait=False, cancel_futures=True) + + self.loop.call_soon_threadsafe(self.loop.stop) + self.loop_thread.join() + # Step 3: Close the asyncio event loop + self.loop_thread._stop() + self.loop_thread._delete() def apply_output_mask(self, pl_module: pl.LightningModule, data: torch.Tensor) -> torch.Tensor: if hasattr(pl_module, "output_mask") and pl_module.output_mask is not None: @@ -147,31 +157,39 @@ def _plot( self, trainer: pl.Trainer, pl_module: pl.LightningModule, - *args, - **kwargs, + *args: Any, + **kwargs: Any, ) -> None: """Plotting function to be implemented by subclasses.""" + # Async function to run the plot function in the background thread + async def submit_plot(self, trainer: pl.Trainer, *args: Any, **kwargs: Any) -> None: + """Async function or coroutine to schedule the plot function.""" + loop = asyncio.get_running_loop() + # run_in_executor doesn't support keyword arguments, + await loop.run_in_executor( + self._executor, + self._plot_with_error_catching, + trainer, + args, + kwargs, + ) # because loop.run_in_executor expects positional arguments, not keyword arguments + @rank_zero_only def _async_plot( self, trainer: pl.Trainer, - *args: list, - **kwargs: dict, + *args: Any, + **kwargs: Any, ) -> None: - """To execute the plot function but ensuring we catch any errors.""" - future = self._executor.submit( - self._plot, - trainer, - *args, - **kwargs, - ) - # otherwise the error won't be thrown till the validation epoch is finished - try: - future.result() - except Exception: - LOGGER.exception("Critical error occurred in asynchronous plots.") - sys.exit(1) + """Run the plot function asynchronously. + + This is the function that is called by the callback. It schedules the plot + function to run in the background thread. Since we have an event loop running in + the background thread, we need to schedule the plot function to run in that + loop. + """ + asyncio.run_coroutine_threadsafe(self.submit_plot(trainer, *args, **kwargs), self.loop) class BasePerBatchPlotCallback(BasePlotCallback): @@ -192,26 +210,12 @@ def __init__(self, config: OmegaConf, every_n_batches: int | None = None): super().__init__(config) self.every_n_batches = every_n_batches or self.config.diagnostics.plot.frequency.batch - @abstractmethod @rank_zero_only - def _plot( + def on_validation_batch_end( self, trainer: pl.Trainer, pl_module: pl.LightningModule, - outputs: list[torch.Tensor], - batch: torch.Tensor, - batch_idx: int, - epoch: int, - **kwargs, - ) -> None: - """Plotting function to be implemented by subclasses.""" - - @rank_zero_only - def on_validation_batch_end( - self, - trainer, - pl_module, - output, + output: list[torch.Tensor], batch: torch.Tensor, batch_idx: int, **kwargs, @@ -310,12 +314,12 @@ def __init__( @rank_zero_only def _plot( self, - trainer, + trainer: pl.Trainer, pl_module: pl.LightningModule, output: list[torch.Tensor], batch: torch.Tensor, - batch_idx, - epoch, + batch_idx: int, + epoch: int, ) -> None: _ = output @@ -406,9 +410,9 @@ def _plot( @rank_zero_only def on_validation_batch_end( self, - trainer, - pl_module, - output, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + output: list[torch.Tensor], batch: torch.Tensor, batch_idx: int, ) -> None: @@ -454,7 +458,7 @@ def _plot( _ = epoch model = pl_module.model.module.model if hasattr(pl_module.model, "module") else pl_module.model.model - fig = plot_graph_node_features(model) + fig = plot_graph_node_features(model, datashader=self.datashader_plotting) self._output_figure( trainer.logger, @@ -750,6 +754,7 @@ def _plot( data[0, ...].squeeze(), data[rollout_step + 1, ...].squeeze(), output_tensor[rollout_step, ...], + datashader=self.datashader_plotting, precip_and_related_fields=self.precip_and_related_fields, ) @@ -839,7 +844,7 @@ def _plot( self, trainer: pl.Trainer, pl_module: pl.LightningModule, - outputs: list, + outputs: list[torch.Tensor], batch: torch.Tensor, batch_idx: int, epoch: int, @@ -921,7 +926,7 @@ def _plot( self, trainer: pl.Trainer, pl_module: pl.LightningModule, - outputs: list, + outputs: list[torch.Tensor], batch: torch.Tensor, batch_idx: int, epoch: int, diff --git a/src/anemoi/training/diagnostics/maps.py b/src/anemoi/training/diagnostics/maps.py index 338a9059..fcf88921 100644 --- a/src/anemoi/training/diagnostics/maps.py +++ b/src/anemoi/training/diagnostics/maps.py @@ -32,7 +32,7 @@ def __init__(self) -> None: def __call__(self, lon: np.ndarray, lat: np.ndarray) -> tuple[np.ndarray, np.ndarray]: lon_rad = np.radians(lon) lat_rad = np.radians(lat) - x = [v - 2 * np.pi if v > np.pi else v for v in lon_rad] + x = np.array([v - 2 * np.pi if v > np.pi else v for v in lon_rad], dtype=lon_rad.dtype) y = lat_rad return x, y diff --git a/src/anemoi/training/diagnostics/plots.py b/src/anemoi/training/diagnostics/plots.py index dde80018..d397f05c 100644 --- a/src/anemoi/training/diagnostics/plots.py +++ b/src/anemoi/training/diagnostics/plots.py @@ -13,10 +13,13 @@ import logging from typing import TYPE_CHECKING +import datashader as dsh import matplotlib.pyplot as plt import matplotlib.style as mplstyle import numpy as np +import pandas as pd from anemoi.models.layers.mapper import GraphEdgeMixin +from datashader.mpl_ext import dsshow from matplotlib.collections import LineCollection from matplotlib.colors import BoundaryNorm from matplotlib.colors import ListedColormap @@ -37,6 +40,7 @@ LOGGER = logging.getLogger(__name__) continents = Coastlines() +LAYOUT = "tight" @dataclass @@ -105,7 +109,7 @@ def plot_loss( # create plot # more space for legend figsize = (8, 3) if legend_patches else (4, 3) - fig, ax = plt.subplots(1, 1, figsize=figsize) + fig, ax = plt.subplots(1, 1, figsize=figsize, layout=LAYOUT) # histogram plot ax.bar(np.arange(x.size), x, color=colors, log=1) @@ -114,8 +118,7 @@ def plot_loss( ax.set_xticks(list(xticks.values()), list(xticks.keys()), rotation=60) if legend_patches: # legend outside and to the right of the plot - plt.legend(handles=legend_patches, bbox_to_anchor=(1.01, 1), loc="upper left") - plt.tight_layout() + ax.legend(handles=legend_patches, bbox_to_anchor=(1.01, 1), loc="upper left") return fig @@ -154,7 +157,7 @@ def plot_power_spectrum( n_plots_x, n_plots_y = len(parameters), 1 figsize = (n_plots_y * 4, n_plots_x * 3) - fig, ax = plt.subplots(n_plots_x, n_plots_y, figsize=figsize) + fig, ax = plt.subplots(n_plots_x, n_plots_y, figsize=figsize, layout=LAYOUT) pc = EquirectangularProjection() lat, lon = latlons[:, 0], latlons[:, 1] @@ -217,7 +220,6 @@ def plot_power_spectrum( ax[plot_idx].set_xlabel("$k$") ax[plot_idx].set_ylabel("$P(k)$") ax[plot_idx].set_aspect("auto", adjustable=None) - fig.tight_layout() return fig @@ -285,7 +287,7 @@ def plot_histogram( n_plots_x, n_plots_y = len(parameters), 1 figsize = (n_plots_y * 4, n_plots_x * 3) - fig, ax = plt.subplots(n_plots_x, n_plots_y, figsize=figsize) + fig, ax = plt.subplots(n_plots_x, n_plots_y, figsize=figsize, layout=LAYOUT) for plot_idx, (variable_idx, (variable_name, output_only)) in enumerate(parameters.items()): yt = y_true[..., variable_idx].squeeze() @@ -325,7 +327,6 @@ def plot_histogram( ax[plot_idx].legend() ax[plot_idx].set_aspect("auto", adjustable=None) - fig.tight_layout() return fig @@ -338,6 +339,7 @@ def plot_predicted_multilevel_flat_sample( x: np.ndarray, y_true: np.ndarray, y_pred: np.ndarray, + datashader: bool = False, precip_and_related_fields: list | None = None, ) -> Figure: """Plots data for one multilevel latlon-"flat" sample. @@ -363,6 +365,8 @@ def plot_predicted_multilevel_flat_sample( Expected data of shape (lat*lon, nvar*level) y_pred : np.ndarray Predicted data of shape (lat*lon, nvar*level) + datashader: bool, optional + Scatter plot, by default False precip_and_related_fields : list, optional List of precipitation-like variables, by default [] @@ -375,7 +379,7 @@ def plot_predicted_multilevel_flat_sample( n_plots_x, n_plots_y = len(parameters), n_plots_per_sample figsize = (n_plots_y * 4, n_plots_x * 3) - fig, ax = plt.subplots(n_plots_x, n_plots_y, figsize=figsize) + fig, ax = plt.subplots(n_plots_x, n_plots_y, figsize=figsize, layout=LAYOUT) pc = EquirectangularProjection() lat, lon = latlons[:, 0], latlons[:, 1] @@ -397,6 +401,7 @@ def plot_predicted_multilevel_flat_sample( variable_name, clevels, cmap_precip, + datashader, precip_and_related_fields, ) else: @@ -411,6 +416,7 @@ def plot_predicted_multilevel_flat_sample( variable_name, clevels, cmap_precip, + datashader, precip_and_related_fields, ) @@ -428,6 +434,7 @@ def plot_flat_sample( vname: str, clevels: float, cmap_precip: str, + datashader: bool = False, precip_and_related_fields: list | None = None, ) -> None: """Plot a "flat" 1D sample. @@ -436,7 +443,7 @@ def plot_flat_sample( Parameters ---------- - fig : _type_ + fig : Figure Figure object handle ax : matplotlib.axes Axis object handle @@ -456,9 +463,14 @@ def plot_flat_sample( Accumulation levels used for precipitation related plots cmap_precip: str Colors used for each accumulation level + datashader: bool, optional + Datashader plott, by default True precip_and_related_fields : list, optional List of precipitation-like variables, by default [] + Returns + ------- + None """ precip_and_related_fields = precip_and_related_fields or [] if vname in precip_and_related_fields: @@ -473,17 +485,38 @@ def plot_flat_sample( # converting to mm from m truth *= 1000.0 pred *= 1000.0 - scatter_plot(fig, ax[1], lon=lon, lat=lat, data=truth, cmap=precip_colormap, norm=norm, title=f"{vname} target") - scatter_plot(fig, ax[2], lon=lon, lat=lat, data=pred, cmap=precip_colormap, norm=norm, title=f"{vname} pred") - scatter_plot( + single_plot( + fig, + ax[1], + lon, + lat, + truth, + cmap=precip_colormap, + norm=norm, + title=f"{vname} target", + datashader=datashader, + ) + single_plot( + fig, + ax[2], + lon, + lat, + pred, + cmap=precip_colormap, + norm=norm, + title=f"{vname} pred", + datashader=datashader, + ) + single_plot( fig, ax[3], - lon=lon, - lat=lat, - data=truth - pred, + lon, + lat, + truth - pred, cmap="bwr", norm=TwoSlopeNorm(vcenter=0.0), title=f"{vname} pred err", + datashader=datashader, ) elif vname == "mwd": cyclic_colormap = "twilight" @@ -495,10 +528,28 @@ def error_plot_in_degrees(array1: np.ndarray, array2: np.ndarray) -> np.ndarray: sample_shape = truth.shape pred = np.maximum(np.zeros(sample_shape), np.minimum(360 * np.ones(sample_shape), (pred))) - scatter_plot(fig, ax[1], lon=lon, lat=lat, data=truth, cmap=cyclic_colormap, title=f"{vname} target") - scatter_plot(fig, ax[2], lon=lon, lat=lat, data=pred, cmap=cyclic_colormap, title=f"capped {vname} pred") + single_plot( + fig, + ax[1], + lon=lon, + lat=lat, + data=truth, + cmap=cyclic_colormap, + title=f"{vname} target", + datashader=datashader, + ) + single_plot( + fig, + ax[2], + lon=lon, + lat=lat, + data=pred, + cmap=cyclic_colormap, + title=f"capped {vname} pred", + datashader=datashader, + ) err_plot = error_plot_in_degrees(truth, pred) - scatter_plot( + single_plot( fig, ax[3], lon=lon, @@ -507,26 +558,37 @@ def error_plot_in_degrees(array1: np.ndarray, array2: np.ndarray) -> np.ndarray: cmap="bwr", norm=TwoSlopeNorm(vcenter=0.0), title=f"{vname} pred err: {np.nanmean(np.abs(err_plot)):.{4}f} deg.", + datashader=datashader, ) else: - scatter_plot(fig, ax[1], lon=lon, lat=lat, data=truth, title=f"{vname} target") - scatter_plot(fig, ax[2], lon=lon, lat=lat, data=pred, title=f"{vname} pred") - scatter_plot( + single_plot(fig, ax[1], lon, lat, truth, title=f"{vname} target", datashader=datashader) + single_plot(fig, ax[2], lon, lat, pred, title=f"{vname} pred", datashader=datashader) + single_plot( fig, ax[3], - lon=lon, - lat=lat, - data=truth - pred, + lon, + lat, + truth - pred, cmap="bwr", norm=TwoSlopeNorm(vcenter=0.0), title=f"{vname} pred err", + datashader=datashader, ) if sum(input_) != 0: if vname == "mwd": - scatter_plot(fig, ax[0], lon=lon, lat=lat, data=input_, cmap=cyclic_colormap, title=f"{vname} input") + single_plot( + fig, + ax[0], + lon=lon, + lat=lat, + data=input_, + cmap=cyclic_colormap, + title=f"{vname} input", + datashader=datashader, + ) err_plot = error_plot_in_degrees(pred, input_) - scatter_plot( + single_plot( fig, ax[4], lon=lon, @@ -535,9 +597,10 @@ def error_plot_in_degrees(array1: np.ndarray, array2: np.ndarray) -> np.ndarray: cmap="bwr", norm=TwoSlopeNorm(vcenter=0.0), title=f"{vname} increment [pred - input] % 360", + datashader=datashader, ) err_plot = error_plot_in_degrees(truth, input_) - scatter_plot( + single_plot( fig, ax[5], lon=lon, @@ -546,28 +609,31 @@ def error_plot_in_degrees(array1: np.ndarray, array2: np.ndarray) -> np.ndarray: cmap="bwr", norm=TwoSlopeNorm(vcenter=0.0), title=f"{vname} persist err: {np.nanmean(np.abs(err_plot)):.{4}f} deg.", + datashader=datashader, ) else: - scatter_plot(fig, ax[0], lon=lon, lat=lat, data=input_, title=f"{vname} input") - scatter_plot( + single_plot(fig, ax[0], lon, lat, input_, title=f"{vname} input", datashader=datashader) + single_plot( fig, ax[4], - lon=lon, - lat=lat, - data=pred - input_, + lon, + lat, + pred - input_, cmap="bwr", norm=TwoSlopeNorm(vcenter=0.0), title=f"{vname} increment [pred - input]", + datashader=datashader, ) - scatter_plot( + single_plot( fig, ax[5], - lon=lon, - lat=lat, - data=truth - input_, + lon, + lat, + truth - input_, cmap="bwr", norm=TwoSlopeNorm(vcenter=0.0), title=f"{vname} persist err", + datashader=datashader, ) else: ax[0].axis("off") @@ -575,18 +641,21 @@ def error_plot_in_degrees(array1: np.ndarray, array2: np.ndarray) -> np.ndarray: ax[5].axis("off") -def scatter_plot( +def single_plot( fig: Figure, - ax: plt.Axes, - *, + ax: plt.axes, lon: np.array, lat: np.array, data: np.array, cmap: str = "viridis", norm: str | None = None, title: str | None = None, + datashader: bool = False, ) -> None: - """Lat-lon scatter plot: can work with arbitrary grids. + """Plot a single lat-lon map. + + Plotting can be made either using datashader plot or Datashader(bin) plots. + By default it uses Datashader since it is faster and more efficient. Parameters ---------- @@ -598,7 +667,7 @@ def scatter_plot( longitude coordinates array, shape (lon,) lat : np.ndarray latitude coordinates array, shape (lat,) - data : _type_ + data : np.ndarray Data to plot cmap : str, optional Colormap string from matplotlib, by default "viridis" @@ -606,18 +675,42 @@ def scatter_plot( Normalization string from matplotlib, by default None title : str, optional Title for plot, by default None + datashader: bool, optional + Scatter plot, by default False + Returns + ------- + None """ - psc = ax.scatter( - lon, - lat, - c=data, - cmap=cmap, - s=1, - alpha=1.0, - norm=norm, - rasterized=True, - ) + if not datashader: + psc = ax.scatter( + lon, + lat, + c=data, + cmap=cmap, + s=1, + alpha=1.0, + norm=norm, + rasterized=False, + ) + else: + df = pd.DataFrame({"val": data, "x": lon, "y": lat}) + # Adjust binning to match the resolution of the data + n_pixels = int(np.floor(data.shape[0] / 212)) + psc = dsshow( + df, + dsh.Point("x", "y"), + dsh.mean("val"), + vmin=data.min(), + vmax=data.max(), + cmap=cmap, + plot_width=n_pixels, + plot_height=n_pixels, + norm=norm, + aspect="auto", + ax=ax, + ) + ax.set_xlim((-np.pi, np.pi)) ax.set_ylim((-np.pi / 2, np.pi / 2)) @@ -644,9 +737,9 @@ def edge_plot( Parameters ---------- - fig : _type_ + fig : Figure Figure object handle - ax : _type_ + ax : matplotlib.axes Axis object handle src_coords : np.ndarray of shape (num_edges, 2) Source latitudes and longitudes. @@ -680,13 +773,15 @@ def edge_plot( fig.colorbar(psc, ax=ax) -def plot_graph_node_features(model: nn.Module) -> Figure: +def plot_graph_node_features(model: nn.Module, datashader: bool = False) -> Figure: """Plot trainable graph node features. Parameters ---------- model: AneomiModelEncProcDec Model object + datashader: bool, optional + Scatter plot, by default False Returns ------- @@ -696,7 +791,7 @@ def plot_graph_node_features(model: nn.Module) -> Figure: nrows = len(nodes_name := model._graph_data.node_types) ncols = min(model.node_attributes.trainable_tensors[m].trainable.shape[1] for m in nodes_name) figsize = (ncols * 4, nrows * 3) - fig, ax = plt.subplots(nrows, ncols, figsize=figsize) + fig, ax = plt.subplots(nrows, ncols, figsize=figsize, layout=LAYOUT) for row, (mesh, trainable_tensor) in enumerate(model.node_attributes.trainable_tensors.items()): latlons = model.node_attributes.get_coordinates(mesh).cpu().numpy() @@ -706,13 +801,14 @@ def plot_graph_node_features(model: nn.Module) -> Figure: for i in range(ncols): ax_ = ax[row, i] if ncols > 1 else ax[row] - scatter_plot( + single_plot( fig, ax_, lon=lon, lat=lat, data=node_features[..., i], title=f"{mesh} trainable feature #{i + 1}", + datashader=datashader, ) return fig @@ -744,7 +840,7 @@ def plot_graph_edge_features(model: nn.Module, q_extreme_limit: float = 0.05) -> ncols = min(module.trainable.trainable.shape[1] for module in trainable_modules.values()) nrows = len(trainable_modules) figsize = (ncols * 4, nrows * 3) - fig, ax = plt.subplots(nrows, ncols, figsize=figsize) + fig, ax = plt.subplots(nrows, ncols, figsize=figsize, layout=LAYOUT) for row, ((src, dst), graph_mapper) in enumerate(trainable_modules.items()): src_coords = model.node_attributes.get_coordinates(src).cpu().numpy() From 23fd04508a3f4e60de02b2a5c11ee10609dcff52 Mon Sep 17 00:00:00 2001 From: Gert Mertes <13658335+gmertes@users.noreply.github.com> Date: Mon, 18 Nov 2024 11:43:40 +0000 Subject: [PATCH 03/16] Increase MLflow HTTP retry (#111) * feat: increase mlflow http retry and timeout * chore: changelog * feat: reduce values * feat: update values for 1 hour of downtime * refactor: set env vars before mlflow import * chore: formatting * Update CHANGELOG.md --- CHANGELOG.md | 3 +++ src/anemoi/training/diagnostics/logger.py | 10 ++++++++++ 2 files changed, 13 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index cdb21ca1..6a97a97c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,9 @@ Keep it human-readable, your future self will thank you! ### Changed ## [0.3.0 - Loss & Callback Refactors](https://github.com/ecmwf/anemoi-training/compare/0.2.2...0.3.0) - 2024-11-14 +### Changed +- Increase the default MlFlow HTTP max retries [#111](https://github.com/ecmwf/anemoi-training/pull/111) + ### Fixed - Rename loss_scaling to variable_loss_scaling [#138](https://github.com/ecmwf/anemoi-training/pull/138) diff --git a/src/anemoi/training/diagnostics/logger.py b/src/anemoi/training/diagnostics/logger.py index 698c7c50..8f3ee729 100644 --- a/src/anemoi/training/diagnostics/logger.py +++ b/src/anemoi/training/diagnostics/logger.py @@ -10,6 +10,7 @@ from __future__ import annotations import logging +import os from pathlib import Path from typing import TYPE_CHECKING @@ -27,6 +28,15 @@ def get_mlflow_logger(config: DictConfig) -> None: LOGGER.debug("MLFlow logging is disabled.") return None + # 35 retries allow for 1 hour of server downtime + http_max_retries = config.diagnostics.log.mlflow.get("http_max_retries", 35) + + os.environ["MLFLOW_HTTP_REQUEST_MAX_RETRIES"] = str(http_max_retries) + os.environ["_MLFLOW_HTTP_REQUEST_MAX_RETRIES_LIMIT"] = str(http_max_retries + 1) + # these are the default values, but set them explicitly in case they change + os.environ["MLFLOW_HTTP_REQUEST_BACKOFF_FACTOR"] = "2" + os.environ["MLFLOW_HTTP_REQUEST_BACKOFF_JITTER"] = "1" + from anemoi.training.diagnostics.mlflow.logger import AnemoiMLflowLogger resumed = config.training.run_id is not None From 7f4023d42116e95cc123d752b3507f07592ac7e4 Mon Sep 17 00:00:00 2001 From: Sara Hahner <44293258+sahahner@users.noreply.github.com> Date: Mon, 18 Nov 2024 16:15:25 +0100 Subject: [PATCH 04/16] Rollout video of variable dynamics (#65) Rollout video of variable dynamics --- Review: @lzampier @HCookie @mc4117 @mchantry --- CHANGELOG.md | 1 + .../config/diagnostics/plot/rollout_eval.yaml | 2 + .../training/diagnostics/callbacks/plot.py | 282 +++++++++++++++--- src/anemoi/training/diagnostics/plots.py | 47 ++- 4 files changed, 279 insertions(+), 53 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6a97a97c..85ec45b3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -44,6 +44,7 @@ Keep it human-readable, your future self will thank you! - Sub-hour datasets [#63](https://github.com/ecmwf/anemoi-training/pull/63) - Add synchronisation workflow [#92](https://github.com/ecmwf/anemoi-training/pull/92) - Feat: Anemoi Profiler compatible with mlflow and using Pytorch (Kineto) Profiler for memory report [38](https://github.com/ecmwf/anemoi-training/pull/38/) +- Feat: Save a gif for longer rollouts in validation [#65](https://github.com/ecmwf/anemoi-training/pull/65) - New limited area config file added, limited_area.yaml. [#134](https://github.com/ecmwf/anemoi-training/pull/134/) - New stretched grid config added, stretched_grid.yaml [#133](https://github.com/ecmwf/anemoi-training/pull/133) diff --git a/src/anemoi/training/config/diagnostics/plot/rollout_eval.yaml b/src/anemoi/training/config/diagnostics/plot/rollout_eval.yaml index 642e6e6b..5eece654 100644 --- a/src/anemoi/training/config/diagnostics/plot/rollout_eval.yaml +++ b/src/anemoi/training/config/diagnostics/plot/rollout_eval.yaml @@ -60,8 +60,10 @@ callbacks: - 10u - 10v - _target_: anemoi.training.diagnostics.callbacks.plot.LongRolloutPlots + # for rollout and video_rollout pick any integers below dataloader.validation_rollout rollout: - ${dataloader.validation_rollout} + video_rollout: ${dataloader.validation_rollout} every_n_epochs: 20 sample_idx: ${diagnostics.plot.sample_idx} parameters: ${diagnostics.plot.parameters} diff --git a/src/anemoi/training/diagnostics/callbacks/plot.py b/src/anemoi/training/diagnostics/callbacks/plot.py index 171eb840..b13d8727 100644 --- a/src/anemoi/training/diagnostics/callbacks/plot.py +++ b/src/anemoi/training/diagnostics/callbacks/plot.py @@ -24,6 +24,7 @@ from pathlib import Path from typing import TYPE_CHECKING +import matplotlib.animation as animation import matplotlib.patches as mpatches import matplotlib.pyplot as plt import numpy as np @@ -31,6 +32,7 @@ from pytorch_lightning.callbacks import Callback from pytorch_lightning.utilities import rank_zero_only +from anemoi.training.diagnostics.plots import get_scatter_frame from anemoi.training.diagnostics.plots import init_plot_settings from anemoi.training.diagnostics.plots import plot_graph_edge_features from anemoi.training.diagnostics.plots import plot_graph_node_features @@ -119,6 +121,35 @@ def _output_figure( plt.close(fig) # cleanup + @rank_zero_only + def _output_gif( + self, + logger: pl.loggers.base.LightningLoggerBase, + fig: plt.Figure, + anim: animation.ArtistAnimation, + epoch: int, + tag: str = "gnn", + ) -> None: + """Animation output: save to file and/or display in notebook.""" + if self.save_basedir is not None: + save_path = Path( + self.save_basedir, + "plots", + f"{tag}_epoch{epoch:03d}.gif", + ) + + save_path.parent.mkdir(parents=True, exist_ok=True) + anim.save(save_path, writer="pillow", fps=8) + + if self.config.diagnostics.log.wandb.enabled: + LOGGER.warning("Saving gif animations not tested for wandb.") + + if self.config.diagnostics.log.mlflow.enabled: + run_id = logger.run_id + logger.experiment.log_artifact(run_id, str(save_path)) + + plt.close(fig) # cleanup + @rank_zero_only def _plot_with_error_catching(self, trainer: pl.Trainer, args: Any, kwargs: Any) -> None: """To execute the plot function but ensuring we catch any errors.""" @@ -261,7 +292,27 @@ def on_validation_epoch_end( class LongRolloutPlots(BasePlotCallback): - """Evaluates the model performance over a (longer) rollout window.""" + """Evaluates the model performance over a (longer) rollout window. + + This function allows evaluating the performance of the model over an extended number + of rollout steps to observe long-term behavior. + Add the callback to the configuration file as follows: + ``` + - _target_: anemoi.training.diagnostics.callbacks.plot.LongRolloutPlots + rollout: + - ${dataloader.validation_rollout} + video_rollout: ${dataloader.validation_rollout} + every_n_epochs: 1 + sample_idx: ${diagnostics.plot.sample_idx} + parameters: ${diagnostics.plot.parameters} + ``` + The selected rollout steps for plots and video need to be lower or equal to dataloader.validation_rollout. + Increasing dataloader.validation_rollout has no effect on the rollout steps during training. + It ensures, that enough time steps are available for the plots and video in the validation batches. + + The runtime of creating one animation of one variable for 56 rollout steps is about 1 minute. + Recommended use for video generation: Fork the run using fork_run_id for 1 additional epochs and enabled videos. + """ def __init__( self, @@ -269,10 +320,12 @@ def __init__( rollout: list[int], sample_idx: int, parameters: list[str], + video_rollout: int = 0, accumulation_levels_plot: list[float] | None = None, cmap_accumulation: list[str] | None = None, per_sample: int = 6, every_n_epochs: int = 1, + animation_interval: int = 400, ) -> None: """Initialise LongRolloutPlots callback. @@ -286,6 +339,8 @@ def __init__( Sample to plot parameters : list[str] Parameters to plot + video_rollout : int, optional + Number of rollout steps for video, by default 0 (no video) accumulation_levels_plot : list[float] | None Accumulation levels to plot, by default None cmap_accumulation : list[str] | None @@ -294,22 +349,39 @@ def __init__( Number of plots per sample, by default 6 every_n_epochs : int, optional Epoch frequency to plot at, by default 1 + animation_interval : int, optional + Delay between frames in the animation in milliseconds, by default 400 """ super().__init__(config) self.every_n_epochs = every_n_epochs - LOGGER.debug( - "Setting up callback for plots with long rollout: rollout = %d, frequency = every %d epoch ...", - rollout, - every_n_epochs, - ) self.rollout = rollout + self.video_rollout = video_rollout + self.max_rollout = 0 + if self.rollout: + self.max_rollout = max(self.rollout) + else: + self.rollout = [] + if self.video_rollout: + self.max_rollout = max(self.max_rollout, self.video_rollout) + self.sample_idx = sample_idx self.accumulation_levels_plot = accumulation_levels_plot self.cmap_accumulation = cmap_accumulation self.per_sample = per_sample self.parameters = parameters + self.animation_interval = animation_interval + + LOGGER.info( + ( + "Setting up callback for plots with long rollout: rollout for plots = %s, ", + "rollout for video = %s, frequency = every %d epoch.", + ), + self.rollout, + self.video_rollout, + every_n_epochs, + ) @rank_zero_only def _plot( @@ -322,12 +394,10 @@ def _plot( epoch: int, ) -> None: _ = output - start_time = time.time() - logger = trainer.logger - # Build dictionary of inidicies and parameters to be plotted + # Initialize required variables for plotting plot_parameters_dict = { pl_module.data_indices.model.output.name_to_index[name]: ( name, @@ -335,15 +405,12 @@ def _plot( ) for name in self.parameters } - if self.post_processors is None: - # Copy to be used across all the training cycle self.post_processors = copy.deepcopy(pl_module.model.post_processors).cpu() if self.latlons is None: self.latlons = np.rad2deg(pl_module.latlons_data.clone().cpu().numpy()) - local_rank = pl_module.local_rank - assert batch.shape[1] >= max(self.rollout) + pl_module.multi_step, ( + assert batch.shape[1] >= self.max_rollout + pl_module.multi_step, ( "Batch length not sufficient for requested validation rollout length! " f"Set `dataloader.validation_rollout` to at least {max(self.rollout)}" ) @@ -358,54 +425,175 @@ def _plot( ].cpu() data_0 = self.post_processors(input_tensor_0).numpy() - # start rollout + if self.video_rollout: + data_over_time = [] + # collect min and max values for each variable for the colorbar + vmin, vmax = (np.inf * np.ones(len(plot_parameters_dict)), -np.inf * np.ones(len(plot_parameters_dict))) + + # Plot for each rollout step# Plot for each rollout step with torch.no_grad(): for rollout_step, (_, _, y_pred) in enumerate( pl_module.rollout_step( batch, - rollout=max(self.rollout), + rollout=self.max_rollout, validation_mode=False, training_mode=False, ), ): - + # plot only if the current rollout step is in the list of rollout steps if (rollout_step + 1) in self.rollout: - # prepare true output tensor for plotting - input_tensor_rollout_step = input_batch[ - self.sample_idx, - pl_module.multi_step + rollout_step, # (pl_module.multi_step - 1) + (rollout_step + 1) - ..., - pl_module.data_indices.internal_data.output.full, - ].cpu() - data_rollout_step = self.post_processors(input_tensor_rollout_step).numpy() - - # prepare predicted output tensor for plotting - output_tensor = self.post_processors( - y_pred[self.sample_idx : self.sample_idx + 1, ...].cpu(), - ).numpy() - - fig = plot_predicted_multilevel_flat_sample( + self._plot_rollout_step( + pl_module, plot_parameters_dict, - self.per_sample, - self.latlons, - self.accumulation_levels_plot, - self.cmap_accumulation, - data_0.squeeze(), - data_rollout_step.squeeze(), - output_tensor[0, 0, :, :], # rolloutstep, first member + input_batch, + data_0, + rollout_step, + y_pred, + batch_idx, + epoch, + logger, ) - self._output_figure( - logger, - fig, - epoch=epoch, - tag=f"gnn_pred_val_sample_rstep{rollout_step + 1:03d}_batch{batch_idx:04d}_rank0", - exp_log_tag=f"val_pred_sample_rstep{rollout_step + 1:03d}_rank{local_rank:01d}", + if self.video_rollout and rollout_step < self.video_rollout: + data_over_time, vmin, vmax = self._store_video_frame_data( + data_over_time, + y_pred, + plot_parameters_dict, + vmin, + vmax, ) - LOGGER.info( - "Time taken to plot samples after longer rollout: %s seconds", - int(time.time() - start_time), + + # Generate and save video rollout animation if enabled + if self.video_rollout: + self._generate_video_rollout( + data_0, + data_over_time, + plot_parameters_dict, + vmin, + vmax, + self.video_rollout, + batch_idx, + epoch, + logger, + animation_interval=self.animation_interval, + ) + + LOGGER.info("Time taken to plot/animate samples for longer rollout: %d seconds", int(time.time() - start_time)) + + def _plot_rollout_step( + self, + pl_module: pl.LightningModule, + plot_parameters_dict: dict, + input_batch: torch.Tensor, + data_0: np.ndarray, + rollout_step: int, + y_pred: torch.Tensor, + batch_idx: int, + epoch: int, + logger: pl.loggers.base.LightningLoggerBase, + ) -> None: + """Plot the predicted output, input, true target and error plots for a given rollout step.""" + # prepare true output tensor for plotting + input_tensor_rollout_step = input_batch[ + self.sample_idx, + pl_module.multi_step + rollout_step, # (pl_module.multi_step - 1) + (rollout_step + 1) + ..., + pl_module.data_indices.internal_data.output.full, + ].cpu() + data_rollout_step = self.post_processors(input_tensor_rollout_step).numpy() + # predicted output tensor + output_tensor = self.post_processors(y_pred[self.sample_idx : self.sample_idx + 1, ...].cpu()).numpy() + + fig = plot_predicted_multilevel_flat_sample( + plot_parameters_dict, + self.per_sample, + self.latlons, + self.accumulation_levels_plot, + self.cmap_accumulation, + data_0.squeeze(), + data_rollout_step.squeeze(), + output_tensor[0, 0, :, :], # rolloutstep, first member ) + self._output_figure( + logger, + fig, + epoch=epoch, + tag=f"gnn_pred_val_sample_rstep{rollout_step + 1:03d}_batch{batch_idx:04d}_rank0", + exp_log_tag=f"val_pred_sample_rstep{rollout_step + 1:03d}_rank{pl_module.local_rank:01d}", + ) + + def _store_video_frame_data( + self, + data_over_time: list, + y_pred: torch.Tensor, + plot_parameters_dict: dict, + vmin: np.ndarray, + vmax: np.ndarray, + ) -> tuple[list, np.ndarray, np.ndarray]: + """Store the data for each frame of the video.""" + # prepare predicted output tensors for video + output_tensor = self.post_processors(y_pred[self.sample_idx : self.sample_idx + 1, ...].cpu()).numpy() + data_over_time.append(output_tensor[0, 0, :, np.array(list(plot_parameters_dict.keys()))]) + # update min and max values for each variable for the colorbar + vmin[:] = np.minimum(vmin, np.nanmin(data_over_time[-1], axis=1)) + vmax[:] = np.maximum(vmax, np.nanmax(data_over_time[-1], axis=1)) + return data_over_time, vmin, vmax + + def _generate_video_rollout( + self, + data_0: np.ndarray, + data_over_time: list, + plot_parameters_dict: dict, + vmin: np.ndarray, + vmax: np.ndarray, + rollout_step: int, + batch_idx: int, + epoch: int, + logger: pl.loggers.base.LightningLoggerBase, + animation_interval: int = 400, + ) -> None: + """Generate the video animation for the rollout.""" + for idx, (variable_idx, (variable_name, _)) in enumerate(plot_parameters_dict.items()): + # Create the animation and list to store the frames (artists) + frames = [] + # Prepare the figure + fig, ax = plt.subplots(figsize=(10, 6), dpi=72) + cmap = "twilight" if variable_name == "mwd" else "viridis" + + # Create initial data and colorbar + ax, scatter_frame = get_scatter_frame( + ax, + data_0[0, :, variable_idx], + self.latlons, + cmap=cmap, + vmin=vmin[idx], + vmax=vmax[idx], + ) + ax.set_title(f"{variable_name}") + fig.colorbar(scatter_frame, ax=ax) + frames.append([scatter_frame]) + + # Loop through the data and create the scatter plot for each frame + for frame_data in data_over_time: + ax, scatter_frame = get_scatter_frame( + ax, + frame_data[idx], + self.latlons, + cmap=cmap, + vmin=vmin[idx], + vmax=vmax[idx], + ) + frames.append([scatter_frame]) # Each frame contains a list of artists (images) + + # Create the animation using ArtistAnimation + anim = animation.ArtistAnimation(fig, frames, interval=animation_interval, blit=True) + self._output_gif( + logger, + fig, + anim, + epoch=epoch, + tag=f"gnn_pred_val_animation_{variable_name}_rstep{rollout_step:02d}_batch{batch_idx:04d}_rank0", + ) @rank_zero_only def on_validation_batch_end( diff --git a/src/anemoi/training/diagnostics/plots.py b/src/anemoi/training/diagnostics/plots.py index d397f05c..e0b44d1c 100644 --- a/src/anemoi/training/diagnostics/plots.py +++ b/src/anemoi/training/diagnostics/plots.py @@ -21,6 +21,7 @@ from anemoi.models.layers.mapper import GraphEdgeMixin from datashader.mpl_ext import dsshow from matplotlib.collections import LineCollection +from matplotlib.collections import PathCollection from matplotlib.colors import BoundaryNorm from matplotlib.colors import ListedColormap from matplotlib.colors import TwoSlopeNorm @@ -50,6 +51,13 @@ class LatLonData: data: np.ndarray +def equirectangular_projection(latlons: np.array) -> np.array: + pc = EquirectangularProjection() + lat, lon = latlons[:, 0], latlons[:, 1] + pc_lon, pc_lat = pc(lon, lat) + return pc_lat, pc_lon + + def init_plot_settings() -> None: """Initialize matplotlib plot settings.""" small_font_size = 8 @@ -159,9 +167,8 @@ def plot_power_spectrum( figsize = (n_plots_y * 4, n_plots_x * 3) fig, ax = plt.subplots(n_plots_x, n_plots_y, figsize=figsize, layout=LAYOUT) - pc = EquirectangularProjection() - lat, lon = latlons[:, 0], latlons[:, 1] - pc_lon, pc_lat = pc(lon, lat) + pc_lat, pc_lon = equirectangular_projection(latlons) + pc_lon = np.array(pc_lon) pc_lat = np.array(pc_lat) # Calculate delta_lon and delta_lat on the projected grid @@ -381,9 +388,7 @@ def plot_predicted_multilevel_flat_sample( figsize = (n_plots_y * 4, n_plots_x * 3) fig, ax = plt.subplots(n_plots_x, n_plots_y, figsize=figsize, layout=LAYOUT) - pc = EquirectangularProjection() - lat, lon = latlons[:, 0], latlons[:, 1] - pc_lon, pc_lat = pc(lon, lat) + pc_lat, pc_lon = equirectangular_projection(latlons) for plot_idx, (variable_idx, (variable_name, output_only)) in enumerate(parameters.items()): xt = x[..., variable_idx].squeeze() * int(output_only) @@ -724,6 +729,36 @@ def single_plot( fig.colorbar(psc, ax=ax) +def get_scatter_frame( + ax: plt.Axes, + data: np.ndarray, + latlons: np.ndarray, + cmap: str = "viridis", + vmin: int | None = None, + vmax: int | None = None, +) -> [plt.Axes, PathCollection]: + """Create a scatter plot for a single frame of an animation.""" + pc_lat, pc_lon = equirectangular_projection(latlons) + + scatter_frame = ax.scatter( + pc_lon, + pc_lat, + c=data, + cmap=cmap, + s=5, + alpha=1.0, + rasterized=True, + vmin=vmin, + vmax=vmax, + ) + ax.set_xlim((-np.pi, np.pi)) + ax.set_ylim((-np.pi / 2, np.pi / 2)) + continents.plot_continents(ax) + ax.set_aspect("auto", adjustable=None) + _hide_axes_ticks(ax) + return ax, scatter_frame + + def edge_plot( fig: Figure, ax: plt.Axes, From 923b266914c8d605507b7fda11917f2a686b81b3 Mon Sep 17 00:00:00 2001 From: Cathal O'Brien Date: Mon, 18 Nov 2024 17:47:47 +0100 Subject: [PATCH 05/16] dont check for tracking uri when running mlflow offline (#146) * no longer checking tracking uri when running mlflow offline old behaviour when running offline with the default tracking uri (???) would cause a crash. My understanding is there is no reason to check tracking uri when running offline because this is provided later when the run is synced to mlflow. One use case for this would be for external vendors benchmarking anemoi. They could use the mlflow logging features to analyse GPU utilization etc. without requiring a tracking uri. * Update src/anemoi/training/diagnostics/logger.py Co-authored-by: Harrison Cook * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: Harrison Cook Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- src/anemoi/training/diagnostics/logger.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/anemoi/training/diagnostics/logger.py b/src/anemoi/training/diagnostics/logger.py index 8f3ee729..2eb82113 100644 --- a/src/anemoi/training/diagnostics/logger.py +++ b/src/anemoi/training/diagnostics/logger.py @@ -43,8 +43,13 @@ def get_mlflow_logger(config: DictConfig) -> None: forked = config.training.fork_run_id is not None save_dir = config.hardware.paths.logs.mlflow - tracking_uri = config.diagnostics.log.mlflow.tracking_uri + offline = config.diagnostics.log.mlflow.offline + if not offline: + tracking_uri = config.diagnostics.log.mlflow.tracking_uri + LOGGER.info("AnemoiMLFlow logging to %s", tracking_uri) + else: + tracking_uri = None if (resumed or forked) and (offline): # when resuming or forking offline - # tracking_uri = ${hardware.paths.logs.mlflow} @@ -64,7 +69,6 @@ def get_mlflow_logger(config: DictConfig) -> None: ) log_hyperparams = False - LOGGER.info("AnemoiMLFlow logging to %s", tracking_uri) logger = AnemoiMLflowLogger( experiment_name=config.diagnostics.log.mlflow.experiment_name, project_name=config.diagnostics.log.mlflow.project_name, From 553f247d4de272b6814ef95cd8fc75eb49c352fb Mon Sep 17 00:00:00 2001 From: Cathal O'Brien Date: Thu, 21 Nov 2024 17:44:56 +0100 Subject: [PATCH 06/16] added red gpu & increased green gpu monitoring (#147) * added red gpu monitoring, increased green gpu monitoring and refactored monitors into their own files * applied feedback * added changelog entry * fixed Changelog entry --- CHANGELOG.md | 1 + .../training/diagnostics/mlflow/logger.py | 53 ++------ .../mlflow/system_metrics/cpu_monitor.py | 41 ++++++ .../mlflow/system_metrics/gpu_monitor.py | 127 ++++++++++++++++++ 4 files changed, 183 insertions(+), 39 deletions(-) create mode 100644 src/anemoi/training/diagnostics/mlflow/system_metrics/cpu_monitor.py create mode 100644 src/anemoi/training/diagnostics/mlflow/system_metrics/gpu_monitor.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 85ec45b3..3abfe1b5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -47,6 +47,7 @@ Keep it human-readable, your future self will thank you! - Feat: Save a gif for longer rollouts in validation [#65](https://github.com/ecmwf/anemoi-training/pull/65) - New limited area config file added, limited_area.yaml. [#134](https://github.com/ecmwf/anemoi-training/pull/134/) - New stretched grid config added, stretched_grid.yaml [#133](https://github.com/ecmwf/anemoi-training/pull/133) +- Custom System monitor for Nvidia and AMD GPUs [#147](https://github.com/ecmwf/anemoi-training/pull/147) ### Changed diff --git a/src/anemoi/training/diagnostics/mlflow/logger.py b/src/anemoi/training/diagnostics/mlflow/logger.py index 7c482ce1..71d5c475 100644 --- a/src/anemoi/training/diagnostics/mlflow/logger.py +++ b/src/anemoi/training/diagnostics/mlflow/logger.py @@ -433,56 +433,31 @@ def experiment(self) -> MLFlowLogger.experiment: def log_system_metrics(self) -> None: """Log system metrics (CPU, GPU, etc).""" import mlflow - import psutil - from mlflow.system_metrics.metrics.base_metrics_monitor import BaseMetricsMonitor from mlflow.system_metrics.metrics.disk_monitor import DiskMonitor - from mlflow.system_metrics.metrics.gpu_monitor import GPUMonitor from mlflow.system_metrics.metrics.network_monitor import NetworkMonitor from mlflow.system_metrics.system_metrics_monitor import SystemMetricsMonitor - class CustomCPUMonitor(BaseMetricsMonitor): - """Class for monitoring CPU stats. - - Extends default CPUMonitor, to also measure total \ - memory and a different formula for calculating used memory. - - """ - - def collect_metrics(self) -> None: - # Get CPU metrics. - cpu_percent = psutil.cpu_percent() - self._metrics["cpu_utilization_percentage"].append(cpu_percent) - - system_memory = psutil.virtual_memory() - # Change the formula for measuring CPU memory usage - # By default Mlflow uses psutil.virtual_memory().used - # Tests have shown that "used" underreports memory usage by as much as a factor of 2, - # "used" also misses increased memory usage from using a higher prefetch factor - self._metrics["system_memory_usage_megabytes"].append( - (system_memory.total - system_memory.available) / 1e6, - ) - self._metrics["system_memory_usage_percentage"].append(system_memory.percent) - - # QOL: report the total system memory in raw numbers - self._metrics["system_memory_total_megabytes"].append(system_memory.total / 1e6) - - def aggregate_metrics(self) -> dict[str, int]: - return {k: round(sum(v) / len(v), 1) for k, v in self._metrics.items()} + from anemoi.training.diagnostics.mlflow.system_metrics.cpu_monitor import CPUMonitor + from anemoi.training.diagnostics.mlflow.system_metrics.gpu_monitor import GreenGPUMonitor + from anemoi.training.diagnostics.mlflow.system_metrics.gpu_monitor import RedGPUMonitor class CustomSystemMetricsMonitor(SystemMetricsMonitor): def __init__(self, run_id: str, resume_logging: bool = False): super().__init__(run_id, resume_logging=resume_logging) - # Replace the CPUMonitor with custom implementation - self.monitors = [CustomCPUMonitor(), DiskMonitor(), NetworkMonitor()] + self.monitors = [CPUMonitor(), DiskMonitor(), NetworkMonitor()] + + # Try init both and catch the error when one init fails try: - gpu_monitor = GPUMonitor() + gpu_monitor = GreenGPUMonitor() self.monitors.append(gpu_monitor) - except ImportError: - LOGGER.warning( - "`pynvml` is not installed, to log GPU metrics please run `pip install pynvml` \ - to install it", - ) + except (ImportError, RuntimeError) as e: + LOGGER.warning("Failed to init Nvidia GPU Monitor: %s", e) + try: + gpu_monitor = RedGPUMonitor() + self.monitors.append(gpu_monitor) + except (ImportError, RuntimeError) as e: + LOGGER.warning("Failed to init AMD GPU Monitor: %s", e) mlflow.enable_system_metrics_logging() system_monitor = CustomSystemMetricsMonitor( diff --git a/src/anemoi/training/diagnostics/mlflow/system_metrics/cpu_monitor.py b/src/anemoi/training/diagnostics/mlflow/system_metrics/cpu_monitor.py new file mode 100644 index 00000000..fbf6b3e5 --- /dev/null +++ b/src/anemoi/training/diagnostics/mlflow/system_metrics/cpu_monitor.py @@ -0,0 +1,41 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import psutil +from mlflow.system_metrics.metrics.base_metrics_monitor import BaseMetricsMonitor + + +class CPUMonitor(BaseMetricsMonitor): + """Class for monitoring CPU stats. + + Extends default CPUMonitor, to also measure total \ + memory and a different formula for calculating used memory. + + """ + + def collect_metrics(self) -> None: + # Get CPU metrics. + cpu_percent = psutil.cpu_percent() + self._metrics["cpu_utilization_percentage"].append(cpu_percent) + + system_memory = psutil.virtual_memory() + # Change the formula for measuring CPU memory usage + # By default Mlflow uses psutil.virtual_memory().used + # Tests have shown that "used" underreports memory usage by as much as a factor of 2, + # "used" also misses increased memory usage from using a higher prefetch factor + self._metrics["system_memory_usage_megabytes"].append( + (system_memory.total - system_memory.available) / 1e6, + ) + self._metrics["system_memory_usage_percentage"].append(system_memory.percent) + + # QOL: report the total system memory in raw numbers + self._metrics["system_memory_total_megabytes"].append(system_memory.total / 1e6) + + def aggregate_metrics(self) -> dict[str, int]: + return {k: round(sum(v) / len(v), 1) for k, v in self._metrics.items()} diff --git a/src/anemoi/training/diagnostics/mlflow/system_metrics/gpu_monitor.py b/src/anemoi/training/diagnostics/mlflow/system_metrics/gpu_monitor.py new file mode 100644 index 00000000..b5a2c132 --- /dev/null +++ b/src/anemoi/training/diagnostics/mlflow/system_metrics/gpu_monitor.py @@ -0,0 +1,127 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import contextlib +import sys + +from mlflow.system_metrics.metrics.base_metrics_monitor import BaseMetricsMonitor + +with contextlib.suppress(ImportError): + import pynvml +with contextlib.suppress(ImportError): + from pyrsmi import rocml + + +class GreenGPUMonitor(BaseMetricsMonitor): + """Class for monitoring Nvidia GPU stats. + + Requires pynvml to be installed. + Extends default GPUMonitor, to also measure total \ + memory + + """ + + def __init__(self): + if "pynvml" not in sys.modules: + # Only instantiate if `pynvml` is installed. + import_error_msg = "`pynvml` is not installed, if you are running on an Nvidia GPU \ + and want to log GPU metrics please run `pip install pynvml`." + raise ImportError(import_error_msg) + try: + # `nvmlInit()` will fail if no GPU is found. + pynvml.nvmlInit() + except pynvml.NVMLError as e: + runtime_error_msg = "Failed to initalize Nvidia GPU monitor: " + raise RuntimeError(runtime_error_msg) from e + + super().__init__() + self.num_gpus = pynvml.nvmlDeviceGetCount() + self.gpu_handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in range(self.num_gpus)] + + def collect_metrics(self) -> None: + # Get GPU metrics. + for i, handle in enumerate(self.gpu_handles): + memory = pynvml.nvmlDeviceGetMemoryInfo(handle) + self._metrics[f"gpu_{i}_memory_usage_percentage"].append( + round(memory.used / memory.total * 100, 1), + ) + self._metrics[f"gpu_{i}_memory_usage_megabytes"].append(memory.used / 1e6) + + # Only record total device memory on GPU 0 to prevent spam + # Unlikely for GPUs on the same node to have different total memory + if i == 0: + self._metrics["gpu_memory_total_megabytes"].append(memory.total / 1e6) + + # Monitor PCIe usage + tx_kilobytes = pynvml.nvmlDeviceGetPcieThroughput(handle, pynvml.NVML_PCIE_UTIL_TX_BYTES) + rx_kilobytes = pynvml.nvmlDeviceGetPcieThroughput(handle, pynvml.NVML_PCIE_UTIL_RX_BYTES) + self._metrics[f"gpu_{i}_pcie_tx_megabytes"].append(tx_kilobytes / 1e3) + self._metrics[f"gpu_{i}_pcie_rx_megabytes"].append(rx_kilobytes / 1e3) + + device_utilization = pynvml.nvmlDeviceGetUtilizationRates(handle) + self._metrics[f"gpu_{i}_utilization_percentage"].append(device_utilization.gpu) + + power_milliwatts = pynvml.nvmlDeviceGetPowerUsage(handle) + power_capacity_milliwatts = pynvml.nvmlDeviceGetEnforcedPowerLimit(handle) + self._metrics[f"gpu_{i}_power_usage_watts"].append(power_milliwatts / 1000) + self._metrics[f"gpu_{i}_power_usage_percentage"].append( + (power_milliwatts / power_capacity_milliwatts) * 100, + ) + + def aggregate_metrics(self) -> dict[str, int]: + return {k: round(sum(v) / len(v), 1) for k, v in self._metrics.items()} + + +class RedGPUMonitor(BaseMetricsMonitor): + """Class for monitoring AMD GPU stats. + + Requires that pyrsmi is installed + Logs utilization and memory usage. + + """ + + def __init__(self): + if "pyrsmi" not in sys.modules: + import_error_msg = "`pyrsmi` is not installed, if you are running on an AMD GPU \ + and want to log GPU metrics please run `pip install pyrsmi`." + # Only instantiate if `pyrsmi` is installed. + raise ImportError(import_error_msg) + try: + # `rocml.smi_initialize()()` will fail if no GPU is found. + rocml.smi_initialize() + except RuntimeError as e: + runtime_error_msg = "Failed to initalize AMD GPU monitor: " + raise RuntimeError(runtime_error_msg) from e + + super().__init__() + self.num_gpus = rocml.smi_get_device_count() + + def collect_metrics(self) -> None: + # Get GPU metrics. + for device in range(self.num_gpus): + memory_used = rocml.smi_get_device_memory_used(device) + memory_total = rocml.smi_get_device_memory_total(device) + memory_busy = rocml.smi_get_device_memory_busy(device) + self._metrics[f"gpu_{device}_memory_usage_percentage"].append( + round(memory_used / memory_total * 100, 1), + ) + self._metrics[f"gpu_{device}_memory_usage_megabytes"].append(memory_used / 1e6) + + self._metrics[f"gpu_{device}_memory_busy_percentage"].append(memory_busy) + + # Only record total device memory on GPU 0 to prevent spam + # Unlikely for GPUs on the same node to have different total memory + if device == 0: + self._metrics["gpu_memory_total_megabytes"].append(memory_total / 1e6) + + utilization = rocml.smi_get_device_utilization(device) + self._metrics[f"gpu_{device}_utilization_percentage"].append(utilization) + + def aggregate_metrics(self) -> dict[str, int]: + return {k: round(sum(v) / len(v), 1) for k, v in self._metrics.items()} From 25abf5e143a29d5931ccb4ac42a5f83c5cd26851 Mon Sep 17 00:00:00 2001 From: Sara Hahner <44293258+sahahner@users.noreply.github.com> Date: Fri, 22 Nov 2024 12:24:17 +0100 Subject: [PATCH 07/16] Feature/mask NaNs in training loss function (#72) * feat: mask NaNs in training loss function --------- Co-authored-by: Jakob Schloer Co-authored-by: Harrison Cook --- CHANGELOG.md | 1 + .../training/config/training/default.yaml | 4 ++- src/anemoi/training/train/forecaster.py | 27 ++++++++++++++++++- 3 files changed, 30 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3abfe1b5..205b92d5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -113,6 +113,7 @@ Keep it human-readable, your future self will thank you! - Feature: Support training for datasets with missing time steps [#48](https://github.com/ecmwf/anemoi-training/pulls/48) - Feature: `AnemoiMlflowClient`, an mlflow client with authentication support [#86](https://github.com/ecmwf/anemoi-training/pull/86) - Long Rollout Plots +- Mask NaN values in training loss function [#72](https://github.com/ecmwf/anemoi-training/pull/72) and [#271](https://github.com/ecmwf-lab/aifs-mono/issues/271) ### Fixed diff --git a/src/anemoi/training/config/training/default.yaml b/src/anemoi/training/config/training/default.yaml index b471034e..1c103827 100644 --- a/src/anemoi/training/config/training/default.yaml +++ b/src/anemoi/training/config/training/default.yaml @@ -48,7 +48,9 @@ training_loss: # Scalars to include in loss calculation # Available scalars include: # - 'variable': See `variable_loss_scaling` for more information - scalars: ['variable'] + # - 'loss_weights_mask': Giving imputed NaNs a zero weight in the loss function + scalars: ['variable', 'loss_weights_mask'] + ignore_nans: False loss_gradient_scaling: False diff --git a/src/anemoi/training/train/forecaster.py b/src/anemoi/training/train/forecaster.py index 80459b8f..5c9f5e84 100644 --- a/src/anemoi/training/train/forecaster.py +++ b/src/anemoi/training/train/forecaster.py @@ -103,7 +103,10 @@ def __init__( # Kwargs to pass to the loss function loss_kwargs = {"node_weights": self.node_weights} # Scalars to include in the loss function, must be of form (dim, scalar) - scalars = {"variable": (-1, variable_scaling)} + # Add mask multiplying NaN locations with zero. At this stage at [[1]]. + # Filled after first application of preprocessor. dimension=[-2, -1] (latlon, n_outputs). + scalars = {"variable": (-1, variable_scaling), "loss_weights_mask": ((-2, -1), torch.ones((1, 1)))} + self.updated_loss_mask = False self.loss = self.get_loss_function(config.training.training_loss, scalars=scalars, **loss_kwargs) @@ -217,6 +220,24 @@ def get_loss_function( return loss_function + def training_weights_for_imputed_variables( + self, + batch: torch.Tensor, + ) -> None: + """Update the loss weights mask for imputed variables.""" + if "loss_weights_mask" in self.loss.scalar: + loss_weights_mask = torch.ones((1, 1), device=batch.device) + # iterate over all pre-processors and check if they have a loss_mask_training attribute + for pre_processor in self.model.pre_processors.processors.values(): + if hasattr(pre_processor, "loss_mask_training"): + loss_weights_mask = loss_weights_mask * pre_processor.loss_mask_training + # if transform_loss_mask function exists for preprocessor apply it + if hasattr(pre_processor, "transform_loss_mask"): + loss_weights_mask = pre_processor.transform_loss_mask(loss_weights_mask) + # update scaler with loss_weights_mask retrieved from preprocessors + self.loss.update_scalar(scalar=loss_weights_mask.cpu(), name="loss_weights_mask") + self.updated_loss_mask = True + @staticmethod def get_val_metric_ranges(config: DictConfig, data_indices: IndexCollection) -> tuple[dict, dict]: @@ -361,6 +382,10 @@ def rollout_step( # for validation not normalized in-place because remappers cannot be applied in-place batch = self.model.pre_processors(batch, in_place=not validation_mode) + if not self.updated_loss_mask: + # update loss scalar after first application and initialization of preprocessors + self.training_weights_for_imputed_variables(batch) + # start rollout of preprocessed batch x = batch[ :, From 11ded7db6dfbb84f3654426aff62d98cc376b2f8 Mon Sep 17 00:00:00 2001 From: gabrieloks <116646686+gabrieloks@users.noreply.github.com> Date: Fri, 22 Nov 2024 13:48:50 +0000 Subject: [PATCH 08/16] [FIX] Power spectra bug on n320 (LAM?) (#149) * Update plots.py To plot the power spectra, we need to create a regular grid (n_pix_lat x n_pix_lon) and interpolate the data on it. The way n_pix_lat and n_pix_lon were previously defined is not robust and might lead to errors. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update plots.py * Update CHANGELOG.md --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- CHANGELOG.md | 2 ++ src/anemoi/training/diagnostics/plots.py | 8 +++----- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 205b92d5..953f2b8a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,8 @@ Keep it human-readable, your future self will thank you! ## [Unreleased](https://github.com/ecmwf/anemoi-training/compare/0.3.0...HEAD) ### Fixed +Fixed bug in power spectra plotting for the n320 resolution. + ### Added ### Changed diff --git a/src/anemoi/training/diagnostics/plots.py b/src/anemoi/training/diagnostics/plots.py index e0b44d1c..6d1d1c8c 100644 --- a/src/anemoi/training/diagnostics/plots.py +++ b/src/anemoi/training/diagnostics/plots.py @@ -171,15 +171,13 @@ def plot_power_spectrum( pc_lon = np.array(pc_lon) pc_lat = np.array(pc_lat) - # Calculate delta_lon and delta_lat on the projected grid - delta_lon = abs(np.diff(pc_lon)) - non_zero_delta_lon = delta_lon[delta_lon != 0] + # Calculate delta_lat on the projected grid delta_lat = abs(np.diff(pc_lat)) non_zero_delta_lat = delta_lat[delta_lat != 0] # Define a regular grid for interpolation - n_pix_lon = int(np.floor(abs(pc_lon.max() - pc_lon.min()) / abs(np.min(non_zero_delta_lon)))) # around 400 for O96 - n_pix_lat = int(np.floor(abs(pc_lat.max() - pc_lat.min()) / abs(np.min(non_zero_delta_lat)))) # around 192 for O96 + n_pix_lat = int(np.floor(abs(pc_lat.max() - pc_lat.min()) / abs(np.min(non_zero_delta_lat)))) + n_pix_lon = (n_pix_lat - 1) * 2 + 1 # 2*lmax + 1 regular_pc_lon = np.linspace(pc_lon.min(), pc_lon.max(), n_pix_lon) regular_pc_lat = np.linspace(pc_lat.min(), pc_lat.max(), n_pix_lat) grid_pc_lon, grid_pc_lat = np.meshgrid(regular_pc_lon, regular_pc_lat) From cf53a6e3eea81a789a2ffc3b9d867afc1ce132d6 Mon Sep 17 00:00:00 2001 From: Jan Polster Date: Fri, 22 Nov 2024 17:11:32 +0100 Subject: [PATCH 09/16] Feature/Improve Dataloader Memory with Read Groups (#76) * feat: improve dataloader memory - Add reader groups to support sharded reading of batches - Add dataloader.read_group_size in config to control read behaviour - Add GraphForecaster.allgather_batch() to reconstruct full batch from shards - Refactor callbacks to call allgather on batches as needed * docs: update docstring with instructions on reader group usage * refactor: rank computations via SLURM_PROCID - Pass model/reader group information from DDPGroupStrategy instead --------- Co-authored-by: gabrieloks <116646686+gabrieloks@users.noreply.github.com> Co-authored-by: Harrison Cook Co-authored-by: sahahner --- CHANGELOG.md | 2 + docs/user-guide/distributed.rst | 4 + .../config/dataloader/native_grid.yaml | 11 ++ src/anemoi/training/data/datamodule.py | 29 ---- src/anemoi/training/data/dataset.py | 82 +++++++++-- .../diagnostics/callbacks/evaluation.py | 4 +- .../training/diagnostics/callbacks/plot.py | 16 ++- src/anemoi/training/distributed/strategy.py | 130 +++++++++++++++--- src/anemoi/training/train/forecaster.py | 81 +++++++++-- src/anemoi/training/train/train.py | 4 +- 10 files changed, 283 insertions(+), 80 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 953f2b8a..88488369 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,8 @@ Fixed bug in power spectra plotting for the n320 resolution. ### Added +- Add reader groups to reduce CPU memory usage and increase dataloader throughput [#76](https://github.com/ecmwf/anemoi-training/pull/76) + ### Changed ## [0.3.0 - Loss & Callback Refactors](https://github.com/ecmwf/anemoi-training/compare/0.2.2...0.3.0) - 2024-11-14 diff --git a/docs/user-guide/distributed.rst b/docs/user-guide/distributed.rst index 40ee4d65..68d7697a 100644 --- a/docs/user-guide/distributed.rst +++ b/docs/user-guide/distributed.rst @@ -45,6 +45,10 @@ number of GPUs you wish to shard the model across. It is recommended to only shard if the model does not fit in GPU memory, as data distribution is a much more efficient way to parallelise the training. +When using model sharding, ``config.dataloader.read_group_size`` allows +for sharded data loading in subgroups. This should be set to the number +of GPUs per model for optimal performance. + ********* Example ********* diff --git a/src/anemoi/training/config/dataloader/native_grid.yaml b/src/anemoi/training/config/dataloader/native_grid.yaml index d7aa4f6d..9513ecc7 100644 --- a/src/anemoi/training/config/dataloader/native_grid.yaml +++ b/src/anemoi/training/config/dataloader/native_grid.yaml @@ -1,6 +1,17 @@ prefetch_factor: 2 pin_memory: True +# ============ +# read_group_size: +# Form subgroups of model comm groups that read data together. +# Each reader in the group only reads 1/read_group_size of the data +# which is then all-gathered between the group. +# This can reduce CPU memory usage as well as increase dataloader throughput. +# The number of GPUs per model must be divisible by read_group_size. +# To disable, set to 1. +# ============ +read_group_size: ${hardware.num_gpus_per_model} + num_workers: training: 8 validation: 8 diff --git a/src/anemoi/training/data/datamodule.py b/src/anemoi/training/data/datamodule.py index 303266fc..6d8e6da0 100644 --- a/src/anemoi/training/data/datamodule.py +++ b/src/anemoi/training/data/datamodule.py @@ -9,7 +9,6 @@ import logging -import os from functools import cached_property from typing import Callable @@ -43,31 +42,6 @@ def __init__(self, config: DictConfig) -> None: self.config = config - self.global_rank = int(os.environ.get("SLURM_PROCID", "0")) # global rank - self.model_comm_group_id = ( - self.global_rank // self.config.hardware.num_gpus_per_model - ) # id of the model communication group the rank is participating in - self.model_comm_group_rank = ( - self.global_rank % self.config.hardware.num_gpus_per_model - ) # rank within one model communication group - total_gpus = self.config.hardware.num_gpus_per_node * self.config.hardware.num_nodes - assert ( - total_gpus - ) % self.config.hardware.num_gpus_per_model == 0, ( - f"GPUs per model {self.config.hardware.num_gpus_per_model} does not divide total GPUs {total_gpus}" - ) - self.model_comm_num_groups = ( - self.config.hardware.num_gpus_per_node - * self.config.hardware.num_nodes - // self.config.hardware.num_gpus_per_model - ) # number of model communication groups - LOGGER.debug( - "Rank %d model communication group number %d, with local model communication group rank %d", - self.global_rank, - self.model_comm_group_id, - self.model_comm_group_rank, - ) - # Set the maximum rollout to be expected self.rollout = ( self.config.training.rollout.max @@ -182,9 +156,6 @@ def _get_dataset( rollout=r, multistep=self.config.training.multistep_input, timeincrement=self.timeincrement, - model_comm_group_rank=self.model_comm_group_rank, - model_comm_group_id=self.model_comm_group_id, - model_comm_num_groups=self.model_comm_num_groups, shuffle=shuffle, label=label, ) diff --git a/src/anemoi/training/data/dataset.py b/src/anemoi/training/data/dataset.py index 9e368f9c..40065e06 100644 --- a/src/anemoi/training/data/dataset.py +++ b/src/anemoi/training/data/dataset.py @@ -36,9 +36,6 @@ def __init__( rollout: int = 1, multistep: int = 1, timeincrement: int = 1, - model_comm_group_rank: int = 0, - model_comm_group_id: int = 0, - model_comm_num_groups: int = 1, shuffle: bool = True, label: str = "generic", ) -> None: @@ -54,12 +51,6 @@ def __init__( time increment between samples, by default 1 multistep : int, optional collate (t-1, ... t - multistep) into the input state vector, by default 1 - model_comm_group_rank : int, optional - process rank in the torch.distributed group (important when running on multiple GPUs), by default 0 - model_comm_group_id: int, optional - device group ID, default 0 - model_comm_num_groups : int, optional - total number of device groups, by default 1 shuffle : bool, optional Shuffle batches, by default True label : str, optional @@ -77,11 +68,14 @@ def __init__( self.n_samples_per_epoch_total: int = 0 self.n_samples_per_epoch_per_worker: int = 0 - # DDP-relevant info - self.model_comm_group_rank = model_comm_group_rank - self.model_comm_num_groups = model_comm_num_groups - self.model_comm_group_id = model_comm_group_id - self.global_rank = int(os.environ.get("SLURM_PROCID", "0")) + # lazy init model and reader group info, will be set by the DDPGroupStrategy: + self.model_comm_group_rank = 0 + self.model_comm_num_groups = 1 + self.model_comm_group_id = 0 + self.global_rank = 0 + + self.reader_group_rank = 0 + self.reader_group_size = 1 # additional state vars (lazy init) self.n_samples_per_worker = 0 @@ -93,6 +87,8 @@ def __init__( assert self.multi_step > 0, "Multistep value must be greater than zero." self.ensemble_dim: int = 2 self.ensemble_size = self.data.shape[self.ensemble_dim] + self.grid_dim: int = -1 + self.grid_size = self.data.shape[self.grid_dim] @cached_property def statistics(self) -> dict: @@ -128,6 +124,58 @@ def valid_date_indices(self) -> np.ndarray: """ return get_usable_indices(self.data.missing, len(self.data), self.rollout, self.multi_step, self.timeincrement) + def set_comm_group_info( + self, + global_rank: int, + model_comm_group_id: int, + model_comm_group_rank: int, + model_comm_num_groups: int, + reader_group_rank: int, + reader_group_size: int, + ) -> None: + """Set model and reader communication group information (called by DDPGroupStrategy). + + Parameters + ---------- + global_rank : int + Global rank + model_comm_group_id : int + Model communication group ID + model_comm_group_rank : int + Model communication group rank + model_comm_num_groups : int + Number of model communication groups + reader_group_rank : int + Reader group rank + reader_group_size : int + Reader group size + """ + self.global_rank = global_rank + self.model_comm_group_id = model_comm_group_id + self.model_comm_group_rank = model_comm_group_rank + self.model_comm_num_groups = model_comm_num_groups + self.reader_group_rank = reader_group_rank + self.reader_group_size = reader_group_size + + if self.reader_group_size > 1: + # get the grid shard size and start/end indices + grid_shard_size = self.grid_size // self.reader_group_size + self.grid_start = self.reader_group_rank * grid_shard_size + if self.reader_group_rank == self.reader_group_size - 1: + self.grid_end = self.grid_size + else: + self.grid_end = (self.reader_group_rank + 1) * grid_shard_size + + LOGGER.debug( + "NativeGridDataset.set_group_info(): global_rank %d, model_comm_group_id %d, " + "model_comm_group_rank %d, model_comm_num_groups %d, reader_group_rank %d", + global_rank, + model_comm_group_id, + model_comm_group_rank, + model_comm_num_groups, + reader_group_rank, + ) + def per_worker_init(self, n_workers: int, worker_id: int) -> None: """Called by worker_init_func on each copy of dataset. @@ -233,7 +281,11 @@ def __iter__(self) -> torch.Tensor: start = i - (self.multi_step - 1) * self.timeincrement end = i + (self.rollout + 1) * self.timeincrement - x = self.data[start : end : self.timeincrement] + if self.reader_group_size > 1: # read only a subset of the grid + x = self.data[start : end : self.timeincrement, :, :, self.grid_start : self.grid_end] + else: # read the full grid + x = self.data[start : end : self.timeincrement, :, :, :] + x = rearrange(x, "dates variables ensemble gridpoints -> dates ensemble gridpoints variables") self.ensemble_dim = 1 diff --git a/src/anemoi/training/diagnostics/callbacks/evaluation.py b/src/anemoi/training/diagnostics/callbacks/evaluation.py index fc812121..cbc929d6 100644 --- a/src/anemoi/training/diagnostics/callbacks/evaluation.py +++ b/src/anemoi/training/diagnostics/callbacks/evaluation.py @@ -15,7 +15,6 @@ import torch from pytorch_lightning.callbacks import Callback -from pytorch_lightning.utilities import rank_zero_only if TYPE_CHECKING: import pytorch_lightning as pl @@ -103,7 +102,6 @@ def _log(self, pl_module: pl.LightningModule, loss: torch.Tensor, metrics: dict, rank_zero_only=True, ) - @rank_zero_only def on_validation_batch_end( self, trainer: pl.Trainer, @@ -114,6 +112,8 @@ def on_validation_batch_end( ) -> None: del outputs # outputs are not used if batch_idx % self.every_n_batches == 0: + batch = pl_module.allgather_batch(batch) + precision_mapping = { "16-mixed": torch.float16, "bf16-mixed": torch.bfloat16, diff --git a/src/anemoi/training/diagnostics/callbacks/plot.py b/src/anemoi/training/diagnostics/callbacks/plot.py index b13d8727..08f9d28b 100644 --- a/src/anemoi/training/diagnostics/callbacks/plot.py +++ b/src/anemoi/training/diagnostics/callbacks/plot.py @@ -241,7 +241,6 @@ def __init__(self, config: OmegaConf, every_n_batches: int | None = None): super().__init__(config) self.every_n_batches = every_n_batches or self.config.diagnostics.plot.frequency.batch - @rank_zero_only def on_validation_batch_end( self, trainer: pl.Trainer, @@ -251,7 +250,16 @@ def on_validation_batch_end( batch_idx: int, **kwargs, ) -> None: + if ( + self.config.diagnostics.plot.asynchronous + and self.config.dataloader.read_group_size > 1 + and pl_module.local_rank == 0 + ): + LOGGER.warning("Asynchronous plotting can result in NCCL timeouts with reader_group_size > 1.") + if batch_idx % self.every_n_batches == 0: + batch = pl_module.allgather_batch(batch) + self.plot( trainer, pl_module, @@ -383,7 +391,6 @@ def __init__( every_n_epochs, ) - @rank_zero_only def _plot( self, trainer: pl.Trainer, @@ -480,6 +487,7 @@ def _plot( LOGGER.info("Time taken to plot/animate samples for longer rollout: %d seconds", int(time.time() - start_time)) + @rank_zero_only def _plot_rollout_step( self, pl_module: pl.LightningModule, @@ -539,6 +547,7 @@ def _store_video_frame_data( vmax[:] = np.maximum(vmax, np.nanmax(data_over_time[-1], axis=1)) return data_over_time, vmin, vmax + @rank_zero_only def _generate_video_rollout( self, data_0: np.ndarray, @@ -595,7 +604,6 @@ def _generate_video_rollout( tag=f"gnn_pred_val_animation_{variable_name}_rstep{rollout_step:02d}_batch{batch_idx:04d}_rank0", ) - @rank_zero_only def on_validation_batch_end( self, trainer: pl.Trainer, @@ -605,6 +613,8 @@ def on_validation_batch_end( batch_idx: int, ) -> None: if (batch_idx) == 0 and (trainer.current_epoch + 1) % self.every_n_epochs == 0: + batch = pl_module.allgather_batch(batch) + precision_mapping = { "16-mixed": torch.float16, "bf16-mixed": torch.bfloat16, diff --git a/src/anemoi/training/distributed/strategy.py b/src/anemoi/training/distributed/strategy.py index c6509795..32c96dc6 100644 --- a/src/anemoi/training/distributed/strategy.py +++ b/src/anemoi/training/distributed/strategy.py @@ -9,7 +9,6 @@ import logging -import os import numpy as np import pytorch_lightning as pl @@ -27,19 +26,22 @@ class DDPGroupStrategy(DDPStrategy): """Distributed Data Parallel strategy with group communication.""" - def __init__(self, num_gpus_per_model: int, **kwargs: dict) -> None: + def __init__(self, num_gpus_per_model: int, read_group_size: int, **kwargs: dict) -> None: """Initialize the distributed strategy. Parameters ---------- num_gpus_per_model : int Number of GPUs per model to shard over. + read_group_size : int + Number of GPUs per reader group. **kwargs : dict Additional keyword arguments. """ super().__init__(**kwargs) self.model_comm_group_size = num_gpus_per_model + self.read_group_size = read_group_size def setup(self, trainer: pl.Trainer) -> None: assert self.accelerator is not None, "Accelerator is not initialized for distributed strategy" @@ -60,18 +62,56 @@ def setup(self, trainer: pl.Trainer) -> None: torch.distributed.new_group(x) for x in model_comm_group_ranks ] # every rank has to create all of these - model_comm_group_id, model_comm_group_nr, model_comm_group_rank = self.get_my_model_comm_group( + model_comm_group_id, model_comm_group_rank, model_comm_num_groups = self.get_my_model_comm_group( self.model_comm_group_size, ) model_comm_group = model_comm_groups[model_comm_group_id] - self.model.set_model_comm_group(model_comm_group) + self.model.set_model_comm_group( + model_comm_group, + model_comm_group_id, + model_comm_group_rank, + model_comm_num_groups, + self.model_comm_group_size, + ) + + # set up reader groups by further splitting model_comm_group_ranks with read_group_size: + + assert self.model_comm_group_size % self.read_group_size == 0, ( + f"Number of GPUs per model ({self.model_comm_group_size}) must be divisible by read_group_size " + f"({self.read_group_size})." + ) + + reader_group_ranks = np.array( + [ + np.split(group_ranks, int(self.model_comm_group_size / self.read_group_size)) + for group_ranks in model_comm_group_ranks + ], + ) # Shape: (num_model_comm_groups, model_comm_grp_size/read_group_size, read_group_size) + reader_groups = [[torch.distributed.new_group(x) for x in group_ranks] for group_ranks in reader_group_ranks] + reader_group_id, reader_group_rank, reader_group_size, reader_group_root = self.get_my_reader_group( + model_comm_group_rank, + self.read_group_size, + ) + # get all reader groups of the current model group + model_reader_groups = reader_groups[model_comm_group_id] + self.model.set_reader_groups( + model_reader_groups, + reader_group_id, + reader_group_rank, + reader_group_size, + ) + LOGGER.debug( - "Rank %d model_comm_group is %s, group number %d, with local group rank %d and comms_group_ranks %s", + "Rank %d model_comm_group_id: %d model_comm_group: %s model_comm_group_rank: %d " + "reader_group_id: %d reader_group: %s reader_group_rank: %d reader_group_root (global): %d", self.global_rank, - str(model_comm_group_nr), model_comm_group_id, - model_comm_group_rank, str(model_comm_group_ranks[model_comm_group_id]), + model_comm_group_rank, + reader_group_id, + reader_group_ranks[model_comm_group_id, reader_group_id], + reader_group_rank, + reader_group_root, ) # register hooks for correct gradient reduction @@ -109,7 +149,7 @@ def setup(self, trainer: pl.Trainer) -> None: # seed ranks self.seed_rnd(model_comm_group_id) - def get_my_model_comm_group(self, num_gpus_per_model: int) -> tuple[int, np.ndarray, int]: + def get_my_model_comm_group(self, num_gpus_per_model: int) -> tuple[int, int, int]: """Determine tasks that work together and from a model group. Parameters @@ -119,19 +159,69 @@ def get_my_model_comm_group(self, num_gpus_per_model: int) -> tuple[int, np.ndar Returns ------- - tuple[int, np.ndarray, int] - Model_comm_group id, Model_comm_group Nr, Model_comm_group rank + tuple[int, int, int] + Model_comm_group id, Model_comm_group rank, Number of model_comm_groups + """ + model_comm_group_id = self.global_rank // num_gpus_per_model + model_comm_group_rank = self.global_rank % num_gpus_per_model + model_comm_num_groups = self.world_size // num_gpus_per_model + + return model_comm_group_id, model_comm_group_rank, model_comm_num_groups + + def get_my_reader_group(self, model_comm_group_rank: int, read_group_size: int) -> tuple[int, int, int]: + """Determine tasks that work together and from a reader group. + + Parameters + ---------- + model_comm_group_rank : int + Rank within the model communication group. + read_group_size : int + Number of dataloader readers per model group. + + Returns + ------- + tuple[int, int, int] + Reader_group id, Reader_group rank, Reader_group root (global rank) """ - model_comm_groups = np.arange(0, self.world_size, dtype=np.int32) - model_comm_groups = np.split(model_comm_groups, self.world_size / num_gpus_per_model) + reader_group_id = model_comm_group_rank // read_group_size + reader_group_rank = model_comm_group_rank % read_group_size + reader_group_size = read_group_size + reader_group_root = (self.global_rank // read_group_size) * read_group_size + + return reader_group_id, reader_group_rank, reader_group_size, reader_group_root - model_comm_group_id = None - for i, model_comm_group in enumerate(model_comm_groups): - if self.global_rank in model_comm_group: - model_comm_group_id = i - model_comm_group_nr = model_comm_group - model_comm_group_rank = np.ravel(np.asarray(model_comm_group == self.global_rank).nonzero())[0] - return model_comm_group_id, model_comm_group_nr, model_comm_group_rank + def process_dataloader(self, dataloader: torch.utils.data.DataLoader) -> torch.utils.data.DataLoader: + """Pass communication group information to the dataloader for distributed training. + + Parameters + ---------- + dataloader : torch.utils.data.DataLoader + Dataloader to process. + + Returns + ------- + torch.utils.data.DataLoader + Processed dataloader. + + """ + dataloader = super().process_dataloader(dataloader) + + # pass model and reader group information to the dataloaders dataset + model_comm_group_id, model_comm_group_rank, model_comm_num_groups = self.get_my_model_comm_group( + self.model_comm_group_size, + ) + _, reader_group_rank, _, _ = self.get_my_reader_group(model_comm_group_rank, self.read_group_size) + + dataloader.dataset.set_comm_group_info( + self.global_rank, + model_comm_group_id, + model_comm_group_rank, + model_comm_num_groups, + reader_group_rank, + self.read_group_size, + ) + + return dataloader def seed_rnd(self, model_comm_group_id: int) -> None: """Seed the random number generators for the rank.""" @@ -145,7 +235,7 @@ def seed_rnd(self, model_comm_group_id: int) -> None: "Strategy: Rank %d, model comm group id %d, base seed %d, seeded with %d, " "running with random seed: %d, sanity rnd: %s" ), - int(os.environ.get("SLURM_PROCID", "0")), + self.global_rank, model_comm_group_id, base_seed, initial_seed, diff --git a/src/anemoi/training/train/forecaster.py b/src/anemoi/training/train/forecaster.py index 5c9f5e84..659c906c 100644 --- a/src/anemoi/training/train/forecaster.py +++ b/src/anemoi/training/train/forecaster.py @@ -9,8 +9,6 @@ import logging -import math -import os from collections import defaultdict from collections.abc import Generator from collections.abc import Mapping @@ -138,17 +136,20 @@ def __init__( self.use_zero_optimizer = config.training.zero_optimizer self.model_comm_group = None + self.reader_groups = None LOGGER.debug("Rollout window length: %d", self.rollout) LOGGER.debug("Rollout increase every : %d epochs", self.rollout_epoch_increment) LOGGER.debug("Rollout max : %d", self.rollout_max) LOGGER.debug("Multistep: %d", self.multi_step) - self.model_comm_group_id = int(os.environ.get("SLURM_PROCID", "0")) // config.hardware.num_gpus_per_model - self.model_comm_group_rank = int(os.environ.get("SLURM_PROCID", "0")) % config.hardware.num_gpus_per_model - self.model_comm_num_groups = math.ceil( - config.hardware.num_gpus_per_node * config.hardware.num_nodes / config.hardware.num_gpus_per_model, - ) + # lazy init model and reader group info, will be set by the DDPGroupStrategy: + self.model_comm_group_id = 0 + self.model_comm_group_rank = 0 + self.model_comm_num_groups = 1 + + self.reader_group_id = 0 + self.reader_group_rank = 0 def forward(self, x: torch.Tensor) -> torch.Tensor: return self.model(x, self.model_comm_group) @@ -313,9 +314,31 @@ def get_variable_scaling( return torch.from_numpy(variable_loss_scaling) - def set_model_comm_group(self, model_comm_group: ProcessGroup) -> None: - LOGGER.debug("set_model_comm_group: %s", model_comm_group) + def set_model_comm_group( + self, + model_comm_group: ProcessGroup, + model_comm_group_id: int, + model_comm_group_rank: int, + model_comm_num_groups: int, + model_comm_group_size: int, + ) -> None: self.model_comm_group = model_comm_group + self.model_comm_group_id = model_comm_group_id + self.model_comm_group_rank = model_comm_group_rank + self.model_comm_num_groups = model_comm_num_groups + self.model_comm_group_size = model_comm_group_size + + def set_reader_groups( + self, + reader_groups: list[ProcessGroup], + reader_group_id: int, + reader_group_rank: int, + reader_group_size: int, + ) -> None: + self.reader_groups = reader_groups + self.reader_group_id = reader_group_id + self.reader_group_rank = reader_group_rank + self.reader_group_size = reader_group_size def advance_input( self, @@ -425,6 +448,8 @@ def _step( validation_mode: bool = False, ) -> tuple[torch.Tensor, Mapping[str, torch.Tensor]]: del batch_idx + batch = self.allgather_batch(batch) + loss = torch.zeros(1, dtype=batch.dtype, device=self.device, requires_grad=False) metrics = {} y_preds = [] @@ -442,6 +467,44 @@ def _step( loss *= 1.0 / self.rollout return loss, metrics, y_preds + def allgather_batch(self, batch: torch.Tensor) -> torch.Tensor: + """Allgather the batch-shards across the reader group. + + Parameters + ---------- + batch : torch.Tensor + Batch-shard of current reader rank + + Returns + ------- + torch.Tensor + Allgathered (full) batch + """ + grid_size = self.model.metadata["dataset"]["shape"][-1] + + if grid_size == batch.shape[-2]: + return batch # already have the full grid + + grid_shard_size = grid_size // self.reader_group_size + last_grid_shard_size = grid_size - (grid_shard_size * (self.reader_group_size - 1)) + + # prepare tensor list with correct shapes for all_gather + shard_shape = list(batch.shape) + shard_shape[-2] = grid_shard_size + last_shard_shape = list(batch.shape) + last_shard_shape[-2] = last_grid_shard_size + + tensor_list = [torch.empty(tuple(shard_shape), device=self.device) for _ in range(self.reader_group_size - 1)] + tensor_list.append(torch.empty(last_shard_shape, device=self.device)) + + torch.distributed.all_gather( + tensor_list, + batch, + group=self.reader_groups[self.reader_group_id], + ) + + return torch.cat(tensor_list, dim=-2) + def calculate_val_metrics( self, y_pred: torch.Tensor, diff --git a/src/anemoi/training/train/train.py b/src/anemoi/training/train/train.py index 553114f5..80fc70d3 100644 --- a/src/anemoi/training/train/train.py +++ b/src/anemoi/training/train/train.py @@ -12,7 +12,6 @@ import datetime import logging -import os from functools import cached_property from pathlib import Path from typing import TYPE_CHECKING @@ -106,7 +105,7 @@ def initial_seed(self) -> int: (torch.rand(1), np_rng.random()) LOGGER.debug( "Initial seed: Rank %d, initial seed %d, running with random seed: %d", - int(os.environ.get("SLURM_PROCID", "0")), + self.strategy.global_rank, initial_seed, rnd_seed, ) @@ -345,6 +344,7 @@ def strategy(self) -> DDPGroupStrategy: """Training strategy.""" return DDPGroupStrategy( self.config.hardware.num_gpus_per_model, + self.config.dataloader.get("read_group_size", self.config.hardware.num_gpus_per_model), static_graph=not self.config.training.accum_grad_batches > 1, ) From e3fe023553230c57d6d1b40720188f9cd28d1f3b Mon Sep 17 00:00:00 2001 From: gabrieloks <116646686+gabrieloks@users.noreply.github.com> Date: Mon, 25 Nov 2024 11:35:37 +0000 Subject: [PATCH 10/16] Validation and Training dataset dates assertion (#154) * Change training end date/validation start date assertion to a warning to allow flexibility. Co-authored-by: Harrison Cook --------- Co-authored-by: Harrison Cook --- src/anemoi/training/data/datamodule.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/anemoi/training/data/datamodule.py b/src/anemoi/training/data/datamodule.py index 6d8e6da0..ba9ff0c3 100644 --- a/src/anemoi/training/data/datamodule.py +++ b/src/anemoi/training/data/datamodule.py @@ -116,10 +116,12 @@ def ds_train(self) -> NativeGridDataset: def ds_valid(self) -> NativeGridDataset: r = max(self.rollout, self.config.dataloader.get("validation_rollout", 1)) - assert self.config.dataloader.training.end < self.config.dataloader.validation.start, ( - f"Training end date {self.config.dataloader.training.end} is not before" - f"validation start date {self.config.dataloader.validation.start}" - ) + if not self.config.dataloader.training.end < self.config.dataloader.validation.start: + LOGGER.warning( + "Training end date %s is not before validation start date %s.", + self.config.dataloader.training.end, + self.config.dataloader.validation.start, + ) return self._get_dataset( open_dataset(OmegaConf.to_container(self.config.dataloader.validation, resolve=True)), shuffle=False, From 0608f21abb78d425d965cd79ea040aaf3a66b5f8 Mon Sep 17 00:00:00 2001 From: Ana Prieto Nemesio <91897203+anaprietonem@users.noreply.github.com> Date: Mon, 25 Nov 2024 15:25:17 +0100 Subject: [PATCH 11/16] warmup config for reproducibility of aifs v0.3 (#155) * warmup config for reproducibility of aifs v0.3 * add entry to changelog * update docs --- CHANGELOG.md | 2 ++ docs/user-guide/training.rst | 15 ++++++++++----- src/anemoi/training/config/training/default.yaml | 1 + src/anemoi/training/train/forecaster.py | 3 ++- 4 files changed, 15 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 88488369..f50ee7f0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,8 @@ Keep it human-readable, your future self will thank you! Fixed bug in power spectra plotting for the n320 resolution. ### Added +- Introduce variable to configure (Cosine Annealing) optimizer warm up [#155](https://github.com/ecmwf/anemoi-training/pull/155) + - Add reader groups to reduce CPU memory usage and increase dataloader throughput [#76](https://github.com/ecmwf/anemoi-training/pull/76) diff --git a/docs/user-guide/training.rst b/docs/user-guide/training.rst index 5be08222..588b34d9 100644 --- a/docs/user-guide/training.rst +++ b/docs/user-guide/training.rst @@ -188,10 +188,11 @@ level has a weighting less than 0.2). *************** Anemoi training uses the ``CosineLRScheduler`` from PyTorch as it's -learning rate scheduler. The user can configure the maximum learning -rate by setting ``config.training.lr.rate``. Note that this learning -rate is scaled by the number of GPUs where for the `data parallelism -`_. +learning rate scheduler. Docs for this scheduler can be found here +https://github.com/huggingface/pytorch-image-models/blob/main/timm/scheduler/cosine_lr.py +The user can configure the maximum learning rate by setting +``config.training.lr.rate``. Note that this learning rate is scaled by +the number of GPUs where for the `data parallelism `_. .. code:: yaml @@ -201,7 +202,11 @@ The user can also control the rate at which the learning rate decreases by setting the total number of iterations through ``config.training.lr.iterations`` and the minimum learning rate reached through ``config.training.lr.min``. Note that the minimum learning rate -is not scaled by the number of GPUs. +is not scaled by the number of GPUs. The user can also control the +warmup period by setting ``config.training.lr.warmup_t``. If the warmup +period is set to 0, the learning rate will start at the maximum learning +rate. If no warmup period is defined, a default warmup period of 1000 +iterations is used. ********* Rollout diff --git a/src/anemoi/training/config/training/default.yaml b/src/anemoi/training/config/training/default.yaml index 1c103827..af168ecc 100644 --- a/src/anemoi/training/config/training/default.yaml +++ b/src/anemoi/training/config/training/default.yaml @@ -83,6 +83,7 @@ lr: rate: 0.625e-4 #local_lr iterations: ${training.max_steps} # NOTE: When max_epochs < max_steps, scheduler will run for max_steps min: 3e-7 #Not scaled by #GPU + warmup_t: 1000 # Changes in per-gpu batch_size should come with a rescaling of the local_lr # in order to keep a constant global_lr diff --git a/src/anemoi/training/train/forecaster.py b/src/anemoi/training/train/forecaster.py index 659c906c..a3abd59c 100644 --- a/src/anemoi/training/train/forecaster.py +++ b/src/anemoi/training/train/forecaster.py @@ -127,6 +127,7 @@ def __init__( * config.training.lr.rate / config.hardware.num_gpus_per_model ) + self.warmup_t = getattr(config.training.lr, "warmup_t", 1000) self.lr_iterations = config.training.lr.iterations self.lr_min = config.training.lr.min self.rollout = config.training.rollout.start @@ -638,6 +639,6 @@ def configure_optimizers(self) -> tuple[list[torch.optim.Optimizer], list[dict]] optimizer, lr_min=self.lr_min, t_initial=self.lr_iterations, - warmup_t=1000, + warmup_t=self.warmup_t, ) return [optimizer], [{"scheduler": scheduler, "interval": "step"}] From 1abb65ea50a4fd00d85c60175af791c4f5f48b98 Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Mon, 25 Nov 2024 17:05:07 +0000 Subject: [PATCH 12/16] hotfix: Expand scalar to prevent index out of bound error (#160) * Disable scalar indices if no variable scalar is used in val_metrics --- src/anemoi/training/losses/weightedloss.py | 2 ++ src/anemoi/training/train/forecaster.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/anemoi/training/losses/weightedloss.py b/src/anemoi/training/losses/weightedloss.py index 0deccc9d..7ed97b21 100644 --- a/src/anemoi/training/losses/weightedloss.py +++ b/src/anemoi/training/losses/weightedloss.py @@ -107,6 +107,8 @@ def scale( if scalar_indices is None: return x * scalar + + scalar = scalar.expand_as(x) return x * scalar[scalar_indices] def scale_by_node_weights(self, x: torch.Tensor, squash: bool = True) -> torch.Tensor: diff --git a/src/anemoi/training/train/forecaster.py b/src/anemoi/training/train/forecaster.py index a3abd59c..f92050cf 100644 --- a/src/anemoi/training/train/forecaster.py +++ b/src/anemoi/training/train/forecaster.py @@ -547,7 +547,7 @@ def calculate_val_metrics( metrics[f"{metric_name}/{mkey}/{rollout_step + 1}"] = metric( y_pred_postprocessed[..., indices], y_postprocessed[..., indices], - scalar_indices=[..., indices], + scalar_indices=[..., indices] if -1 in metric.scalar else None, ) return metrics From fa430782331f9825eaf986cc5e80ed7a0dd83364 Mon Sep 17 00:00:00 2001 From: Sara Hahner <44293258+sahahner@users.noreply.github.com> Date: Tue, 26 Nov 2024 10:05:39 +0100 Subject: [PATCH 13/16] Callback PlotHistogram breaks if only one variable is specified. (#165) * enable plothistogram and plotspectrum for only one variable --- CHANGELOG.md | 3 ++- src/anemoi/training/diagnostics/plots.py | 4 ++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f50ee7f0..80a4bcca 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,7 +11,8 @@ Keep it human-readable, your future self will thank you! ## [Unreleased](https://github.com/ecmwf/anemoi-training/compare/0.3.0...HEAD) ### Fixed -Fixed bug in power spectra plotting for the n320 resolution. +- Fixed bug in power spectra plotting for the n320 resolution. +- Allow histogram and spectrum plot for one variable [#165](https://github.com/ecmwf/anemoi-training/pull/165) ### Added - Introduce variable to configure (Cosine Annealing) optimizer warm up [#155](https://github.com/ecmwf/anemoi-training/pull/155) diff --git a/src/anemoi/training/diagnostics/plots.py b/src/anemoi/training/diagnostics/plots.py index 6d1d1c8c..93e2d324 100644 --- a/src/anemoi/training/diagnostics/plots.py +++ b/src/anemoi/training/diagnostics/plots.py @@ -166,6 +166,8 @@ def plot_power_spectrum( figsize = (n_plots_y * 4, n_plots_x * 3) fig, ax = plt.subplots(n_plots_x, n_plots_y, figsize=figsize, layout=LAYOUT) + if n_plots_x == 1: + ax = [ax] pc_lat, pc_lon = equirectangular_projection(latlons) @@ -293,6 +295,8 @@ def plot_histogram( figsize = (n_plots_y * 4, n_plots_x * 3) fig, ax = plt.subplots(n_plots_x, n_plots_y, figsize=figsize, layout=LAYOUT) + if n_plots_x == 1: + ax = [ax] for plot_idx, (variable_idx, (variable_name, output_only)) in enumerate(parameters.items()): yt = y_true[..., variable_idx].squeeze() From 112d78f8fff7ef95e88e0aa2735857bf31c7d912 Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Wed, 27 Nov 2024 10:16:29 +0000 Subject: [PATCH 14/16] Remove excess mlflow params (#169) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Remove ‘metadata.dataset.variables_metadata’ from parms to log --- src/anemoi/training/diagnostics/mlflow/logger.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/anemoi/training/diagnostics/mlflow/logger.py b/src/anemoi/training/diagnostics/mlflow/logger.py index 71d5c475..78a80be9 100644 --- a/src/anemoi/training/diagnostics/mlflow/logger.py +++ b/src/anemoi/training/diagnostics/mlflow/logger.py @@ -501,7 +501,16 @@ def _clean_params(params: dict[str, Any]) -> dict[str, Any]: dict[str, Any] Cleaned up params ready for MlFlow. """ - prefixes_to_remove = ["hardware", "data", "dataloader", "model", "training", "diagnostics", "metadata.config"] + prefixes_to_remove = [ + "hardware", + "data", + "dataloader", + "model", + "training", + "diagnostics", + "metadata.config", + "metadata.dataset.variables_metadata", + ] keys_to_remove = [key for key in params if any(key.startswith(prefix) for prefix in prefixes_to_remove)] for key in keys_to_remove: del params[key] From ac38f928648b02cc837d457d093a7e27a11fd3fd Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz <48736305+JPXKQX@users.noreply.github.com> Date: Thu, 28 Nov 2024 09:39:41 +0100 Subject: [PATCH 15/16] Upgrade configs to anemoi-graphs 0.4.1 (#159) * chore: update configs to anemoi-graphs=0.4.1 * feat: bump anemoi-graphs version requirement to >= 0.4.1 * fix: target_mask_attr_name inside edge builder * fix: remove from default * Update CHANGELOG.md * Update pyproject.toml * Update CHANGELOG.md * Update CHANGELOG.md * fix: lam plotting * fix: cast to DotDict * fix: add import * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix: anemoi-graphs 0.4.1 format --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- CHANGELOG.md | 3 +-- pyproject.toml | 4 ++-- .../config/graph/encoder_decoder_only.yaml | 10 +++++----- .../training/config/graph/limited_area.yaml | 14 +++++++------- .../training/config/graph/multi_scale.yaml | 16 ++++++++-------- .../training/config/graph/stretched_grid.yaml | 12 ++++++------ .../training/diagnostics/callbacks/plot.py | 4 ++-- src/anemoi/training/train/train.py | 4 +++- 8 files changed, 34 insertions(+), 33 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 80a4bcca..26b71522 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,9 +16,8 @@ Keep it human-readable, your future self will thank you! ### Added - Introduce variable to configure (Cosine Annealing) optimizer warm up [#155](https://github.com/ecmwf/anemoi-training/pull/155) - - - Add reader groups to reduce CPU memory usage and increase dataloader throughput [#76](https://github.com/ecmwf/anemoi-training/pull/76) +- Bump `anemoi-graphs` version to 0.4.1 [#159](https://github.com/ecmwf/anemoi-training/pull/159) ### Changed ## [0.3.0 - Loss & Callback Refactors](https://github.com/ecmwf/anemoi-training/compare/0.2.2...0.3.0) - 2024-11-14 diff --git a/pyproject.toml b/pyproject.toml index 8d685acd..10d8efda 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,8 +40,8 @@ classifiers = [ dynamic = [ "version" ] dependencies = [ - "anemoi-datasets>=0.4", - "anemoi-graphs>=0.4", + "anemoi-datasets>=0.5.2", + "anemoi-graphs>=0.4.1", "anemoi-models>=0.3", "anemoi-utils[provenance]>=0.4.4", "datashader>=0.16.3", diff --git a/src/anemoi/training/config/graph/encoder_decoder_only.yaml b/src/anemoi/training/config/graph/encoder_decoder_only.yaml index b813254d..76907d82 100644 --- a/src/anemoi/training/config/graph/encoder_decoder_only.yaml +++ b/src/anemoi/training/config/graph/encoder_decoder_only.yaml @@ -22,15 +22,15 @@ edges: # Encoder configuration - source_name: ${graph.data} target_name: ${graph.hidden} - edge_builder: - _target_: anemoi.graphs.edges.CutOffEdges # options: KNNEdges, CutOffEdges + edge_builders: + - _target_: anemoi.graphs.edges.CutOffEdges # options: KNNEdges, CutOffEdges cutoff_factor: 0.6 # only for cutoff method attributes: ${graph.attributes.edges} -- source_name: ${graph.hidden} # Decoder configuration +- source_name: ${graph.hidden} target_name: ${graph.data} - edge_builder: - _target_: anemoi.graphs.edges.KNNEdges # options: KNNEdges, CutOffEdges + edge_builders: + - _target_: anemoi.graphs.edges.KNNEdges # options: KNNEdges, CutOffEdges num_nearest_neighbours: 3 # only for knn method attributes: ${graph.attributes.edges} diff --git a/src/anemoi/training/config/graph/limited_area.yaml b/src/anemoi/training/config/graph/limited_area.yaml index f17bc384..a22405b6 100644 --- a/src/anemoi/training/config/graph/limited_area.yaml +++ b/src/anemoi/training/config/graph/limited_area.yaml @@ -23,23 +23,23 @@ edges: # Encoder configuration - source_name: ${graph.data} target_name: ${graph.hidden} - edge_builder: - _target_: anemoi.graphs.edges.CutOffEdges # options: KNNEdges, CutOffEdges + edge_builders: + - _target_: anemoi.graphs.edges.CutOffEdges # options: KNNEdges, CutOffEdges cutoff_factor: 0.6 # only for cutoff method attributes: ${graph.attributes.edges} # Processor configuration - source_name: ${graph.hidden} target_name: ${graph.hidden} - edge_builder: - _target_: anemoi.graphs.edges.MultiScaleEdges + edge_builders: + - _target_: anemoi.graphs.edges.MultiScaleEdges x_hops: 1 attributes: ${graph.attributes.edges} # Decoder configuration - source_name: ${graph.hidden} target_name: ${graph.data} - target_mask_attr_name: cutout - edge_builder: - _target_: anemoi.graphs.edges.KNNEdges # options: KNNEdges, CutOffEdges + edge_builders: + - _target_: anemoi.graphs.edges.KNNEdges # options: KNNEdges, CutOffEdges + target_mask_attr_name: cutout num_nearest_neighbours: 3 # only for knn method attributes: ${graph.attributes.edges} diff --git a/src/anemoi/training/config/graph/multi_scale.yaml b/src/anemoi/training/config/graph/multi_scale.yaml index 7e54535e..eec38d82 100644 --- a/src/anemoi/training/config/graph/multi_scale.yaml +++ b/src/anemoi/training/config/graph/multi_scale.yaml @@ -22,22 +22,22 @@ edges: # Encoder configuration - source_name: ${graph.data} target_name: ${graph.hidden} - edge_builder: - _target_: anemoi.graphs.edges.CutOffEdges # options: KNNEdges, CutOffEdges + edge_builders: + - _target_: anemoi.graphs.edges.CutOffEdges # options: KNNEdges, CutOffEdges cutoff_factor: 0.6 # only for cutoff method attributes: ${graph.attributes.edges} -- source_name: ${graph.hidden} # Processor configuration +- source_name: ${graph.hidden} target_name: ${graph.hidden} - edge_builder: - _target_: anemoi.graphs.edges.MultiScaleEdges + edge_builders: + - _target_: anemoi.graphs.edges.MultiScaleEdges x_hops: 1 attributes: ${graph.attributes.edges} -- source_name: ${graph.hidden} # Decoder configuration +- source_name: ${graph.hidden} target_name: ${graph.data} - edge_builder: - _target_: anemoi.graphs.edges.KNNEdges # options: KNNEdges, CutOffEdges + edge_builders: + - _target_: anemoi.graphs.edges.KNNEdges # options: KNNEdges, CutOffEdges num_nearest_neighbours: 3 # only for knn method attributes: ${graph.attributes.edges} diff --git a/src/anemoi/training/config/graph/stretched_grid.yaml b/src/anemoi/training/config/graph/stretched_grid.yaml index dad0172d..a92f319b 100644 --- a/src/anemoi/training/config/graph/stretched_grid.yaml +++ b/src/anemoi/training/config/graph/stretched_grid.yaml @@ -34,22 +34,22 @@ edges: # Encoder - source_name: ${graph.data} target_name: ${graph.hidden} - edge_builder: - _target_: anemoi.graphs.edges.KNNEdges + edge_builders: + - _target_: anemoi.graphs.edges.KNNEdges num_nearest_neighbours: 12 attributes: ${graph.attributes.edges} # Processor - source_name: ${graph.hidden} target_name: ${graph.hidden} - edge_builder: - _target_: anemoi.graphs.edges.MultiScaleEdges + edge_builders: + - _target_: anemoi.graphs.edges.MultiScaleEdges x_hops: 1 attributes: ${graph.attributes.edges} # Decoder - source_name: ${graph.hidden} target_name: ${graph.data} - edge_builder: - _target_: anemoi.graphs.edges.KNNEdges + edge_builders: + - _target_: anemoi.graphs.edges.KNNEdges num_nearest_neighbours: 3 attributes: ${graph.attributes.edges} diff --git a/src/anemoi/training/diagnostics/callbacks/plot.py b/src/anemoi/training/diagnostics/callbacks/plot.py index 08f9d28b..a54bab70 100644 --- a/src/anemoi/training/diagnostics/callbacks/plot.py +++ b/src/anemoi/training/diagnostics/callbacks/plot.py @@ -938,7 +938,7 @@ def _plot( torch.cat(tuple(x[self.sample_idx : self.sample_idx + 1, ...].cpu() for x in outputs[1])), in_place=False, ) - output_tensor = pl_module.output_mask.apply(output_tensor, dim=2, fill_value=np.nan).numpy() + output_tensor = pl_module.output_mask.apply(output_tensor, dim=1, fill_value=np.nan).numpy() data[1:, ...] = pl_module.output_mask.apply(data[1:, ...], dim=2, fill_value=np.nan) data = data.numpy() @@ -999,7 +999,7 @@ def process( torch.cat(tuple(x[self.sample_idx : self.sample_idx + 1, ...].cpu() for x in outputs[1])), in_place=False, ) - output_tensor = pl_module.output_mask.apply(output_tensor, dim=2, fill_value=np.nan).numpy() + output_tensor = pl_module.output_mask.apply(output_tensor, dim=1, fill_value=np.nan).numpy() data[1:, ...] = pl_module.output_mask.apply(data[1:, ...], dim=2, fill_value=np.nan) data = data.numpy() return data, output_tensor diff --git a/src/anemoi/training/train/train.py b/src/anemoi/training/train/train.py index 80fc70d3..a18ed4dc 100644 --- a/src/anemoi/training/train/train.py +++ b/src/anemoi/training/train/train.py @@ -20,6 +20,7 @@ import numpy as np import pytorch_lightning as pl import torch +from anemoi.utils.config import DotDict from anemoi.utils.provenance import gather_provenance_info from omegaconf import DictConfig from omegaconf import OmegaConf @@ -128,7 +129,8 @@ def graph_data(self) -> HeteroData: from anemoi.graphs.create import GraphCreator - return GraphCreator(config=self.config.graph).create( + graph_config = DotDict(OmegaConf.to_container(self.config.graph, resolve=True)) + return GraphCreator(config=graph_config).create( save_path=graph_filename, overwrite=self.config.graph.overwrite, ) From 2809a81e7d2a2843ce1098a771cd536b5237aad5 Mon Sep 17 00:00:00 2001 From: Ana Prieto Nemesio <91897203+anaprietonem@users.noreply.github.com> Date: Thu, 28 Nov 2024 09:50:33 +0100 Subject: [PATCH 16/16] Change number of pixels used by datashader (#152) * change hardcoded value of n_pixels - Tested with LAM & Global --- CHANGELOG.md | 2 ++ src/anemoi/training/diagnostics/plots.py | 27 +++++++++++++++--------- 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 26b71522..4305b870 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,8 @@ Keep it human-readable, your future self will thank you! ## [Unreleased](https://github.com/ecmwf/anemoi-training/compare/0.3.0...HEAD) ### Fixed +- Update `n_pixel` used by datashader to better adapt across resolutions #152 + - Fixed bug in power spectra plotting for the n320 resolution. - Allow histogram and spectrum plot for one variable [#165](https://github.com/ecmwf/anemoi-training/pull/165) diff --git a/src/anemoi/training/diagnostics/plots.py b/src/anemoi/training/diagnostics/plots.py index 93e2d324..45818b69 100644 --- a/src/anemoi/training/diagnostics/plots.py +++ b/src/anemoi/training/diagnostics/plots.py @@ -24,6 +24,7 @@ from matplotlib.collections import PathCollection from matplotlib.colors import BoundaryNorm from matplotlib.colors import ListedColormap +from matplotlib.colors import Normalize from matplotlib.colors import TwoSlopeNorm from pyshtools.expand import SHGLQ from pyshtools.expand import SHExpandGLQ @@ -568,8 +569,12 @@ def error_plot_in_degrees(array1: np.ndarray, array2: np.ndarray) -> np.ndarray: datashader=datashader, ) else: - single_plot(fig, ax[1], lon, lat, truth, title=f"{vname} target", datashader=datashader) - single_plot(fig, ax[2], lon, lat, pred, title=f"{vname} pred", datashader=datashader) + combined_data = np.concatenate((input_, truth, pred)) + # For 'errors', only persistence and increments need identical colorbar-limits + combined_error = np.concatenate(((pred - input_), (truth - input_))) + norm = Normalize(vmin=np.nanmin(combined_data), vmax=np.nanmax(combined_data)) + single_plot(fig, ax[1], lon, lat, truth, norm=norm, title=f"{vname} target", datashader=datashader) + single_plot(fig, ax[2], lon, lat, pred, norm=norm, title=f"{vname} pred", datashader=datashader) single_plot( fig, ax[3], @@ -619,7 +624,7 @@ def error_plot_in_degrees(array1: np.ndarray, array2: np.ndarray) -> np.ndarray: datashader=datashader, ) else: - single_plot(fig, ax[0], lon, lat, input_, title=f"{vname} input", datashader=datashader) + single_plot(fig, ax[0], lon, lat, input_, norm=norm, title=f"{vname} input", datashader=datashader) single_plot( fig, ax[4], @@ -627,7 +632,7 @@ def error_plot_in_degrees(array1: np.ndarray, array2: np.ndarray) -> np.ndarray: lat, pred - input_, cmap="bwr", - norm=TwoSlopeNorm(vcenter=0.0), + norm=TwoSlopeNorm(vmin=combined_error.min(), vcenter=0.0, vmax=combined_error.max()), title=f"{vname} increment [pred - input]", datashader=datashader, ) @@ -638,7 +643,7 @@ def error_plot_in_degrees(array1: np.ndarray, array2: np.ndarray) -> np.ndarray: lat, truth - input_, cmap="bwr", - norm=TwoSlopeNorm(vcenter=0.0), + norm=TwoSlopeNorm(vmin=combined_error.min(), vcenter=0.0, vmax=combined_error.max()), title=f"{vname} persist err", datashader=datashader, ) @@ -703,13 +708,13 @@ def single_plot( else: df = pd.DataFrame({"val": data, "x": lon, "y": lat}) # Adjust binning to match the resolution of the data - n_pixels = int(np.floor(data.shape[0] / 212)) + lower_limit = 25 + upper_limit = 500 + n_pixels = max(min(int(np.floor(data.shape[0] * 0.004)), upper_limit), lower_limit) psc = dsshow( df, dsh.Point("x", "y"), dsh.mean("val"), - vmin=data.min(), - vmax=data.max(), cmap=cmap, plot_width=n_pixels, plot_height=n_pixels, @@ -718,8 +723,10 @@ def single_plot( ax=ax, ) - ax.set_xlim((-np.pi, np.pi)) - ax.set_ylim((-np.pi / 2, np.pi / 2)) + xmin, xmax = max(lon.min(), -np.pi), min(lon.max(), np.pi) + ymin, ymax = max(lat.min(), -np.pi / 2), min(lat.max(), np.pi / 2) + ax.set_xlim((xmin - 0.1, xmax + 0.1)) + ax.set_ylim((ymin - 0.1, ymax + 0.1)) continents.plot_continents(ax)