diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index e861091e..b5d7431b 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -35,7 +35,7 @@ jobs:
   # Run CI including downstream packages on self-hosted runners
   downstream-ci:
     name: downstream-ci
-    if: ${{ !github.event.pull_request.head.repo.fork && github.event.action != 'labeled' || github.event.label.name == 'approved-for-ci' }}
+    if: ${{ !contains(github.repository, 'private') && (!github.event.pull_request.head.repo.fork && github.event.action != 'labeled' || github.event.label.name == 'approved-for-ci') }}
     uses: ecmwf-actions/downstream-ci/.github/workflows/downstream-ci.yml@main
     with:
       anemoi-training: ecmwf/anemoi-training@${{ github.event.pull_request.head.sha || github.sha }}
@@ -45,7 +45,7 @@ jobs:
   # Build downstream packages on HPC
   downstream-ci-hpc:
     name: downstream-ci-hpc
-    if: ${{ !github.event.pull_request.head.repo.fork && github.event.action != 'labeled' || github.event.label.name == 'approved-for-ci' }}
+    if: ${{ !contains(github.repository, 'private') && (!github.event.pull_request.head.repo.fork && github.event.action != 'labeled' || github.event.label.name == 'approved-for-ci') }}
     uses: ecmwf-actions/downstream-ci/.github/workflows/downstream-ci-hpc.yml@main
     with:
       anemoi-training: ecmwf/anemoi-training@${{ github.event.pull_request.head.sha || github.sha }}
diff --git a/.github/workflows/push-to-private.yml b/.github/workflows/push-to-private.yml
new file mode 100644
index 00000000..4cc53efd
--- /dev/null
+++ b/.github/workflows/push-to-private.yml
@@ -0,0 +1,33 @@
+name: Push to private repository
+
+on:
+  push:
+    branches:
+      - develop
+
+jobs:
+  push_changes:
+    if: ${{ !contains(github.repository, 'private') }}
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Checkout source repository
+        uses: actions/checkout@v3
+        with:
+          fetch-depth: 0
+          fetch-tags: true
+
+      - name: Set up Git configuration
+        run: |
+          git config user.name "github-actions[bot]"
+          git config user.email "github-actions[bot]@users.noreply.github.com"
+
+      - name: Setup SSH key
+        uses: webfactory/ssh-agent@v0.5.0
+        with:
+          ssh-private-key: ${{ secrets.KEY_TO_PRIVATE }}
+
+      - name: Push changes to private repository
+        run: |
+          git remote add private git@github.com:${{ github.repository }}-private.git
+          git push --set-upstream private develop
diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml
index 1db47a1b..6bb65533 100644
--- a/.github/workflows/python-publish.yml
+++ b/.github/workflows/python-publish.yml
@@ -10,11 +10,13 @@ on:
 jobs:
   quality:
+    if: ${{ !contains(github.repository, 'private') }}
     uses: ecmwf-actions/reusable-workflows/.github/workflows/qa-precommit-run.yml@v2
     with:
       skip-hooks: "no-commit-to-branch"
 
   checks:
+    if: ${{ !contains(github.repository, 'private') }}
     strategy:
       matrix:
         python-version: ["3.9", "3.10", "3.11", "3.12"]
@@ -23,6 +25,7 @@ jobs:
       python-version: ${{ matrix.python-version }}
 
   deploy:
+    if: ${{ !contains(github.repository, 'private') }}
     needs: [checks, quality]
     uses: ecmwf-actions/reusable-workflows/.github/workflows/cd-pypi.yml@v2
     secrets: inherit
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index e01d6a37..bbc225df 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -27,7 +27,7 @@ repos:
       - id: python-check-blanket-noqa # Check for # noqa: all
       - id: python-no-log-warn # Check for log.warn
   - repo: https://github.com/psf/black-pre-commit-mirror
-    rev: 24.8.0
+    rev: 24.10.0
    hooks:
      - id: black
        args: [--line-length=120]
@@ -40,16 +40,15 @@ repos:
        - --force-single-line-imports
        - --profile black
  - repo: https://github.com/astral-sh/ruff-pre-commit
-   rev: v0.6.9
+   rev: v0.7.2
    hooks:
      - id: ruff
-       # Next line if for documenation cod snippets
-       exclude: '^[^_].*_\.py$'
        args:
        - --line-length=120
        - --fix
        - --exit-non-zero-on-fix
        - --preview
+       - --exclude=docs/**/*_.py
  - repo: https://github.com/sphinx-contrib/sphinx-lint
    rev: v1.0.0
    hooks:
@@ -60,13 +59,8 @@ repos:
    hooks:
      - id: rstfmt
        exclude: 'cli/.*' # Because we use argparse
- - repo: https://github.com/b8raoult/pre-commit-docconvert
-   rev: "0.1.5"
-   hooks:
-     - id: docconvert
-       args: ["numpy"]
  - repo: https://github.com/tox-dev/pyproject-fmt
-   rev: "2.2.4"
+   rev: "v2.5.0"
    hooks:
      - id: pyproject-fmt
  - repo: https://github.com/jshwi/docsig # Check docstrings against function sig
diff --git a/CHANGELOG.md b/CHANGELOG.md
index bb550b02..21ef6ff3 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,12 +10,40 @@ Keep it human-readable, your future self will thank you!
 
 ## [Unreleased](https://github.com/ecmwf/anemoi-training/compare/0.2.2...HEAD)
 
+### Fixed
+
+- Rename loss_scaling to variable_loss_scaling [#138](https://github.com/ecmwf/anemoi-training/pull/138)
+- Refactored callbacks. [#60](https://github.com/ecmwf/anemoi-training/pull/60)
+  - Updated docs [#115](https://github.com/ecmwf/anemoi-training/pull/115)
+  - Fix enabling LearningRateMonitor [#119](https://github.com/ecmwf/anemoi-training/pull/119)
+- Refactored rollout [#87](https://github.com/ecmwf/anemoi-training/pull/87)
+  - Enable longer validation rollout than training
+- Expand iterables in logging [#91](https://github.com/ecmwf/anemoi-training/pull/91)
+  - Save entire config in mlflow
+
+### Added
+
+- Included more loss functions and allowed configuration [#70](https://github.com/ecmwf/anemoi-training/pull/70)
+  - Fix that applies the metric_ranges in the post-processed variable space [#116](https://github.com/ecmwf/anemoi-training/pull/116)
+- Allow updates to scalars [#137](https://github.com/ecmwf/anemoi-training/pull/137)
+  - Add without subsetting in ScaleTensor
+- Sub-hour datasets [#63](https://github.com/ecmwf/anemoi-training/pull/63)
+- Add synchronisation workflow [#92](https://github.com/ecmwf/anemoi-training/pull/92)
+- Feat: Anemoi Profiler compatible with mlflow and using Pytorch (Kineto) Profiler for memory report [#38](https://github.com/ecmwf/anemoi-training/pull/38/)
+- New limited area config file added, limited_area.yaml. [#134](https://github.com/ecmwf/anemoi-training/pull/134/)
+- New stretched grid config added, stretched_grid.yaml [#133](https://github.com/ecmwf/anemoi-training/pull/133)
+
+### Changed
+
+- Renamed frequency keys in callbacks configuration. [#118](https://github.com/ecmwf/anemoi-training/pull/118)
+- Modified training configuration to support max_steps and tied lr iterations to max_steps by default [#67](https://github.com/ecmwf/anemoi-training/pull/67)
+- Merged node & edge trainable feature callbacks into one. [#135](https://github.com/ecmwf/anemoi-training/pull/135)
+
 ## [0.2.2 - Maintenance: pin python <3.13](https://github.com/ecmwf/anemoi-training/compare/0.2.1...0.2.2) - 2024-10-28
+
 ### Changed
 
 - Lock python version <3.13 [#107](https://github.com/ecmwf/anemoi-training/pull/107)
 
+
 ## [0.2.1 - Bugfix: resuming mlflow runs](https://github.com/ecmwf/anemoi-training/compare/0.2.0...0.2.1) - 2024-10-24
 
 ### Added
@@ -27,6 +55,7 @@ Keep it human-readable, your future self will thank you!
 
 ### Fixed
 
+- Fix pre-commit regex
 - Mlflow-sync to handle creation of new experiments in the remote server [#83](https://github.com/ecmwf/anemoi-training/pull/83)
 - Fix for multi-gpu when using mlflow due to refactoring of _get_mlflow_run_params function [#99](https://github.com/ecmwf/anemoi-training/pull/99)
 - ci: fix pyshtools install error [#100](https://github.com/ecmwf/anemoi-training/pull/100)
@@ -51,6 +83,8 @@ Keep it human-readable, your future self will thank you!
 
 - Introduction of remapper to anemoi-models leads to changes in the data indices. Some preprocessors cannot be applied in-place anymore.
 
+- Variable Bounding as configurable model layers [#13](https://github.com/ecmwf/anemoi-models/issues/13)
+
 #### Functionality
 
 - Enable the callback for plotting a histogram for variables containing NaNs
@@ -62,6 +96,7 @@ Keep it human-readable, your future self will thank you!
 - Feature: `AnemoiMlflowClient`, an mlflow client with authentication support [#86](https://github.com/ecmwf/anemoi-training/pull/86)
 - Long Rollout Plots
 
+
 ### Fixed
 
 - Fix `TypeError` raised when trying to JSON serialise `datetime.timedelta` object - [#43](https://github.com/ecmwf/anemoi-training/pull/43)
@@ -76,6 +111,7 @@ Keep it human-readable, your future self will thank you!
 - Updated configuration examples in documentation and corrected links - [#46](https://github.com/ecmwf/anemoi-training/pull/46)
 - Remove credential prompt from mlflow login, replace with seed refresh token via web - [#78](https://github.com/ecmwf/anemoi-training/pull/78)
 - Update CODEOWNERS
+- Change how mlflow measures CPU Memory usage - [#94](https://github.com/ecmwf/anemoi-training/pull/94)
 
 ## [0.1.0 - Anemoi training - First release](https://github.com/ecmwf/anemoi-training/releases/tag/0.1.0) - 2024-08-16
diff --git a/docs/conf.py b/docs/conf.py
index 12e25dd5..dc7a24d9 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -1,3 +1,12 @@
+# (C) Copyright 2024 Anemoi contributors.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+
 # Configuration file for the Sphinx documentation builder.
 #
 # This file only contains a selection of the most common options. For a full
@@ -13,10 +22,11 @@
 import datetime
 import os
 import sys
+from pathlib import Path
 
 read_the_docs_build = os.environ.get("READTHEDOCS", None) == "True"
 
-sys.path.insert(0, os.path.join(os.path.abspath(".."), "src"))
+sys.path.insert(0, str(Path("..").absolute() / "src"))
 
 source_suffix = ".rst"
@@ -30,13 +40,12 @@
 
 project = "Anemoi Training"
 
-author = "ECMWF"
+author = "Anemoi contributors"
 
-year = datetime.datetime.now().year
+year = datetime.datetime.now(tz=datetime.timezone.utc).year
 years = "2024" if year == 2024 else f"2024-{year}"
-copyright = f"{years}, ECMWF"
-
+copyright = f"{years}, Anemoi contributors"  # noqa: A001
 
 try:
     from anemoi.training._version import __version__
@@ -64,7 +73,7 @@
 ]
 
 # Add any paths that contain templates here, relative to this directory.
-# templates_path = ["_templates"]
+# templates_path = ["_templates"]  # noqa: ERA001
 
 # List of patterns, relative to source directory, that match files and
 # directories to ignore when looking for source files.
diff --git a/docs/images/profiler/anemoi_profiler_architecture.png b/docs/images/profiler/anemoi_profiler_architecture.png
new file mode 100644
index 00000000..483571d1
Binary files /dev/null and b/docs/images/profiler/anemoi_profiler_architecture.png differ
diff --git a/docs/images/profiler/anemoi_profiler_benchmark_profiler.png b/docs/images/profiler/anemoi_profiler_benchmark_profiler.png
new file mode 100644
index 00000000..5cc6d7d1
Binary files /dev/null and b/docs/images/profiler/anemoi_profiler_benchmark_profiler.png differ
diff --git a/docs/images/profiler/anemoi_profiler_config.png b/docs/images/profiler/anemoi_profiler_config.png
new file mode 100644
index 00000000..dd98469b
Binary files /dev/null and b/docs/images/profiler/anemoi_profiler_config.png differ
diff --git a/docs/images/profiler/anemoi_profiler_high_level.png b/docs/images/profiler/anemoi_profiler_high_level.png
new file mode 100644
index 00000000..bd86c4fe
Binary files /dev/null and b/docs/images/profiler/anemoi_profiler_high_level.png differ
diff --git a/docs/images/profiler/anemoi_profiler_mlflow_integration.png b/docs/images/profiler/anemoi_profiler_mlflow_integration.png
new file mode 100644
index 00000000..cbd03d9f
Binary files /dev/null and b/docs/images/profiler/anemoi_profiler_mlflow_integration.png differ
diff --git a/docs/images/profiler/anemoi_profiler_mlflow_integration_2.png b/docs/images/profiler/anemoi_profiler_mlflow_integration_2.png
new file mode 100644
index 00000000..196b8818
Binary files /dev/null and b/docs/images/profiler/anemoi_profiler_mlflow_integration_2.png differ
diff --git a/docs/images/profiler/anemoi_profiler_mlflow_integration_3.png b/docs/images/profiler/anemoi_profiler_mlflow_integration_3.png
new file mode 100644
index 00000000..d4897502
Binary files /dev/null and b/docs/images/profiler/anemoi_profiler_mlflow_integration_3.png differ
diff --git a/docs/images/profiler/anemoi_profiler_speed_report.png b/docs/images/profiler/anemoi_profiler_speed_report.png
new file mode 100644
index 00000000..dbec34e4
Binary files /dev/null and b/docs/images/profiler/anemoi_profiler_speed_report.png differ
diff --git a/docs/images/profiler/anemoi_profiler_speedreport_diagram.png b/docs/images/profiler/anemoi_profiler_speedreport_diagram.png
new file mode 100644
index 00000000..a69324e4
Binary files /dev/null and b/docs/images/profiler/anemoi_profiler_speedreport_diagram.png differ
diff --git a/docs/images/profiler/anemoi_profiler_training_rates.png b/docs/images/profiler/anemoi_profiler_training_rates.png
new file mode 100644
index 00000000..e26da246
Binary files /dev/null and b/docs/images/profiler/anemoi_profiler_training_rates.png differ
diff --git a/docs/images/profiler/anemoi_profiler_validation_rates.png b/docs/images/profiler/anemoi_profiler_validation_rates.png
new file mode 100644
index 00000000..aa352cde
Binary files /dev/null and b/docs/images/profiler/anemoi_profiler_validation_rates.png differ
diff --git a/docs/images/profiler/example_memory_report.png b/docs/images/profiler/example_memory_report.png
new file mode 100644
index 00000000..0f42ebd0
Binary files /dev/null and b/docs/images/profiler/example_memory_report.png differ
diff --git a/docs/images/profiler/example_memory_timeline.png b/docs/images/profiler/example_memory_timeline.png
new file mode 100644
index 00000000..93591893
Binary files /dev/null and b/docs/images/profiler/example_memory_timeline.png differ
diff --git a/docs/images/profiler/example_model_summary.png b/docs/images/profiler/example_model_summary.png
new file mode 100644
index 00000000..498eff30
Binary files /dev/null and b/docs/images/profiler/example_model_summary.png differ
diff --git a/docs/images/profiler/example_model_summary_2.png b/docs/images/profiler/example_model_summary_2.png
new file mode 100644
index 00000000..c8adc538
Binary files /dev/null and b/docs/images/profiler/example_model_summary_2.png differ
diff --git a/docs/images/profiler/example_system_report.png b/docs/images/profiler/example_system_report.png
new file mode 100644
index 00000000..f6f002fa
Binary files /dev/null and b/docs/images/profiler/example_system_report.png differ
diff --git a/docs/images/profiler/example_time_report.png b/docs/images/profiler/example_time_report.png
new file mode 100644
index 00000000..b8918a33
Binary files /dev/null and b/docs/images/profiler/example_time_report.png differ
diff --git a/docs/images/profiler/idle_time_breakdown.png b/docs/images/profiler/idle_time_breakdown.png
new file mode 100644
index 00000000..e183b010
Binary files /dev/null and b/docs/images/profiler/idle_time_breakdown.png differ
diff --git a/docs/images/profiler/kernel_breakdown_dfs.png b/docs/images/profiler/kernel_breakdown_dfs.png
new file mode 100644
index 00000000..20aee8c7
Binary files /dev/null and b/docs/images/profiler/kernel_breakdown_dfs.png differ
diff --git a/docs/images/profiler/kernel_breakdown_plots.png b/docs/images/profiler/kernel_breakdown_plots.png
new file mode 100644
index 00000000..e36d59a4
Binary files /dev/null and b/docs/images/profiler/kernel_breakdown_plots.png differ
diff --git a/docs/images/profiler/memory_snapshot_diagram.png b/docs/images/profiler/memory_snapshot_diagram.png
new file mode 100644
index 00000000..87ca6669
Binary files /dev/null and b/docs/images/profiler/memory_snapshot_diagram.png differ
diff --git a/docs/images/profiler/memory_snapshot_output.png b/docs/images/profiler/memory_snapshot_output.png
new file mode 100644
index 00000000..b22b9f4f
Binary files /dev/null and b/docs/images/profiler/memory_snapshot_output.png differ
diff --git a/docs/images/profiler/temporal_breakdown.png b/docs/images/profiler/temporal_breakdown.png
new file mode 100644
index 00000000..a3a0370e
Binary files /dev/null and b/docs/images/profiler/temporal_breakdown.png differ
diff --git a/docs/index.rst b/docs/index.rst
index bfbd2dbf..f5b36758 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -43,6 +43,7 @@ This package provides the *Anemoi* training functionality.
    user-guide/training
    user-guide/models
    user-guide/tracking
+   user-guide/benchmarking
    user-guide/distributed
    user-guide/debugging
 
diff --git a/docs/modules/diagnostics.rst b/docs/modules/diagnostics.rst
index c46f201c..28eac7c7 100644
--- a/docs/modules/diagnostics.rst
+++ b/docs/modules/diagnostics.rst
@@ -21,23 +21,94 @@ functionality to use both Weights & Biases and Tensorboard.
 
 The callbacks can also be used to evaluate forecasts over longer
 rollouts beyond the forecast time that the model is trained on. The
-number of rollout steps (or forecast iteration steps) is set using
-``config.eval.rollout = *num_of_rollout_steps*``.
-
-Note the user has the option to evaluate the callbacks asynchronously
-(using the following config option
-``config.diagnostics.plot.asynchronous``, which means that the model
-training doesn't stop whilst the callbacks are being evaluated).
-However, note that callbacks can still be slow, and therefore the
-plotting callbacks can be switched off by setting
-``config.diagnostics.plot.enabled`` to ``False`` or all the callbacks
-can be completely switched off by setting
-``config.diagnostics.eval.enabled`` to ``False``.
+number of rollout steps for verification (or forecast iteration steps)
+is set using ``config.dataloader.validation_rollout =
+*num_of_rollout_steps*``.
+
+Callbacks are configured in the config file under the
+``config.diagnostics`` key.
+
+Regular callbacks can be provided as a list of dictionaries underneath
+the ``config.diagnostics.callbacks`` key. Each dictionary must have a
+``_target_`` key, which is used by hydra to instantiate the callback;
+any other kwarg is passed to the callback's constructor.
+
+.. code:: yaml
+
+   callbacks:
+     - _target_: anemoi.training.diagnostics.callbacks.evaluation.RolloutEval
+       rollout: ${dataloader.validation_rollout}
+       frequency: 20
+
+Plotting callbacks are configured in a similar way, but they are
+specified underneath the ``config.diagnostics.plot.callbacks`` key.
+
+This is done to ensure separation and ease of configuration between
+experiments.
+
+``config.diagnostics.plot`` is a broader config specifying the
+parameters to plot, as well as the plotting frequency and
+asynchronicity.
+
+Setting ``config.diagnostics.plot.asynchronous`` means that model
+training doesn't stop whilst the callbacks are being evaluated.
+
+.. code:: yaml
+
+   plot:
+     asynchronous: True # Whether to plot asynchronously
+     frequency: # Frequency of the plotting
+       batch: 750
+       epoch: 5
+
+     # Parameters to plot
+     parameters:
+       - z_500
+       - t_850
+       - u_850
+
+     # Sample index
+     sample_idx: 0
+
+     # Precipitation and related fields
+     precip_and_related_fields: [tp, cp]
+
+     callbacks:
+       - _target_: anemoi.training.diagnostics.callbacks.plot.PlotLoss
+         # group parameters by categories when visualizing contributions to the loss
+         # one-parameter groups are possible to highlight individual parameters
+         parameter_groups:
+           moisture: [tp, cp, tcw]
+           sfc_wind: [10u, 10v]
+       - _target_: anemoi.training.diagnostics.callbacks.plot.PlotSample
+         sample_idx: ${diagnostics.plot.sample_idx}
+         per_sample: 6
+         parameters: ${diagnostics.plot.parameters}
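+Since callbacks are instantiated through hydra, ``_target_`` can also
+point at a user-defined class, as long as it is importable. As a
+minimal sketch (the module and class names here are hypothetical; we
+assume a standard PyTorch Lightning ``Callback`` subclass):
+
+.. code:: python
+
+   import pytorch_lightning as pl
+
+
+   class BatchTimer(pl.Callback):
+       """Hypothetical callback: print a message every N training batches."""
+
+       def __init__(self, every_n_batches: int = 100) -> None:
+           super().__init__()
+           self.every_n_batches = every_n_batches
+
+       def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
+           if batch_idx % self.every_n_batches == 0:
+               pl_module.print(f"finished training batch {batch_idx}")
+
+It would then be referenced from the config exactly like the built-in
+callbacks:
+
+.. code:: yaml
+
+   callbacks:
+     - _target_: my_package.callbacks.BatchTimer
+       every_n_batches: 50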
 
 Below is the documentation for the default callbacks provided, but it
 is also possible for users to add callbacks using the same structure:
 
-.. automodule:: anemoi.training.diagnostics.callbacks
+.. automodule:: anemoi.training.diagnostics.callbacks.checkpoint
+   :members:
+   :no-undoc-members:
+   :show-inheritance:
+
+.. automodule:: anemoi.training.diagnostics.callbacks.evaluation
+   :members:
+   :no-undoc-members:
+   :show-inheritance:
+
+.. automodule:: anemoi.training.diagnostics.callbacks.optimiser
+   :members:
+   :no-undoc-members:
+   :show-inheritance:
+
+.. automodule:: anemoi.training.diagnostics.callbacks.plot
+   :members:
+   :no-undoc-members:
+   :show-inheritance:
+
+.. automodule:: anemoi.training.diagnostics.callbacks.provenance
    :members:
    :no-undoc-members:
    :show-inheritance:
diff --git a/docs/modules/losses.rst b/docs/modules/losses.rst
index 73045d62..32ad9783 100644
--- a/docs/modules/losses.rst
+++ b/docs/modules/losses.rst
@@ -3,22 +3,153 @@
 ########
 
 This module is used to define the loss function used to train the
 model.
+
+Anemoi-training exposes several loss functions by default, all of
+which are subclassed from ``BaseWeightedLoss``. This class enables
+scalar multiplication and graph node weighting.
+
+.. automodule:: anemoi.training.losses.weightedloss
+   :members:
+   :no-undoc-members:
+   :show-inheritance:
+
+************************
+ Default Loss Functions
+************************
+
 By default anemoi-training trains the model using a latitude-weighted
 mean-squared-error, which is defined in the ``WeightedMSELoss`` class in
-``aifs/losses/mse.py``.
+``anemoi/training/losses/mse.py``. The loss function can be configured
+in the config file at ``config.training.training_loss`` and
+``config.training.validation_metrics``.
+
+The following loss functions are available by default:
+
+- ``WeightedMSELoss``: Latitude-weighted mean-squared-error.
+- ``WeightedMAELoss``: Latitude-weighted mean-absolute-error.
+- ``WeightedHuberLoss``: Latitude-weighted Huber loss.
+- ``WeightedLogCoshLoss``: Latitude-weighted log-cosh loss.
+- ``WeightedRMSELoss``: Latitude-weighted root-mean-squared-error.
+- ``CombinedLoss``: Combined component weighted loss.
+
+These are available in the ``anemoi.training.losses`` module, at
+``anemoi.training.losses.{short_name}.{class_name}``.
+
+For example, to use the ``WeightedMSELoss`` class, you would reference
+it in the config as follows:
+
+.. code:: yaml
+
+   # loss function for the model
+   training_loss:
+     # loss class to initialise
+     _target_: anemoi.training.losses.mse.WeightedMSELoss
+     # loss function kwargs here
+
+*********
+ Scalars
+*********
+
+In addition to node scaling, the loss function can also be scaled by a
+scalar. These are provided by the ``Forecaster`` class, and a user can
+define whether to include them in the loss function by setting
+``scalars`` in the loss config dictionary.
+
+.. code:: yaml
+
+   # loss function for the model
+   training_loss:
+     # loss class to initialise
+     _target_: anemoi.training.losses.mse.WeightedMSELoss
+     scalars: ['scalar1', 'scalar2']
+
+Currently, the following scalars are available for use:
+
+- ``variable``: Scale by the feature/variable weights as defined in the
+  config ``config.training.variable_loss_scaling``.
 
-The user can define their own loss function using the same structure as
-the ``WeightedMSELoss`` class.
+********************
+ Validation Metrics
+********************
 
-.. automodule:: anemoi.training.losses.mse
+Validation metrics, as defined in the config file at
+``config.training.validation_metrics``, follow the same initialisation
+behaviour as the loss function, but can be given as a list. In that
+case all the losses are calculated and logged as a dictionary under
+their corresponding names.
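+
+For instance, a list of two metrics might be configured as follows (a
+sketch: the module paths simply follow the
+``{short_name}.{class_name}`` pattern described above):
+
+.. code:: yaml
+
+   validation_metrics:
+     - _target_: anemoi.training.losses.rmse.WeightedRMSELoss
+     - _target_: anemoi.training.losses.mae.WeightedMAELoss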
+
+***********************
+ Custom Loss Functions
+***********************
+
+Additionally, you can define your own loss function by subclassing
+``BaseWeightedLoss`` and implementing the ``forward`` method, or by
+subclassing ``FunctionalWeightedLoss`` and implementing the
+``calculate_difference`` function. The latter abstracts away the
+scaling and node weighting, and lets you specify only the difference
+calculation.
+
+.. code:: python
+
+   from anemoi.training.losses.weightedloss import FunctionalWeightedLoss
+
+   class MyLossFunction(FunctionalWeightedLoss):
+       def calculate_difference(self, pred, target):
+           return (pred - target) ** 2
+
+Then, in the config, set ``_target_`` to the class name, plus any
+additional kwargs for the loss function.
+
+*****************
+ Combined Losses
+*****************
+
+Building on the simple single loss functions, a user can define a
+combined loss: one that weights and combines multiple loss functions.
+
+This can be done by referencing the ``CombinedLoss`` class in the
+config file, and setting the ``losses`` key to a list of loss functions
+to combine. Each of those losses is then initialised just like the
+other losses above.
+
+.. code:: yaml
+
+   training_loss:
+     _target_: anemoi.training.losses.combined.CombinedLoss
+     losses:
+       - _target_: anemoi.training.losses.mse.WeightedMSELoss
+       - _target_: anemoi.training.losses.mae.WeightedMAELoss
+     scalars: ['variable']
+     loss_weights: [1.0, 0.5]
+
+All kwargs passed to ``CombinedLoss`` are passed to each of the loss
+functions, and the loss weights are used to scale the individual losses
+before combining them.
+
+.. automodule:: anemoi.training.losses.combined
    :members:
    :no-undoc-members:
    :show-inheritance:
 
+*******************
+ Utility Functions
+*******************
+
-There is also generic functions that are useful for losses in
-``aifs/losses/utils.py``.
+There are also generic functions that are useful for losses in
+``anemoi/training/losses/utils.py``.
 
 ``grad_scaler`` is used to automatically scale the loss gradients in
 the loss function using the formula in
 https://arxiv.org/pdf/2306.06079.pdf, section 4.3.2. This can be
 switched on in the config by setting the option
 ``config.training.loss_gradient_scaling=True``.
+
+``ScaleTensor`` is a class that can record and apply arbitrary scaling
+factors to tensors. It supports relative indexing and combining
+multiple scalars over the same dimensions, and the scaling is only
+constructed at broadcasting time, so its shape can be resolved to match
+the target tensor exactly.
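+
+As a rough illustration of the intent (a hypothetical sketch: the
+method names below are assumptions for illustration, not a documented
+API):
+
+.. code:: python
+
+   import torch
+
+   from anemoi.training.losses.utils import ScaleTensor
+
+   # record two named scalars acting on the same (last) dimension
+   scale = ScaleTensor()
+   scale.add_scalar(-1, torch.tensor([0.5, 2.0]), name="variable")
+   scale.add_scalar(-1, torch.tensor([1.0, 0.1]), name="level")
+
+   # the combined scaling is only resolved against a concrete shape here
+   x = torch.ones(3, 2)
+   scaled = x * scale.get_scalar(x.ndim)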
+
+.. automodule:: anemoi.training.losses.utils
+   :members:
+   :no-undoc-members:
+   :show-inheritance:
diff --git a/docs/overview.rst b/docs/overview.rst
index 11611b6f..268e287c 100644
--- a/docs/overview.rst
+++ b/docs/overview.rst
@@ -91,6 +91,18 @@ and resolve issues during the training process, including:
 
 - Debug configurations for quick error identification
 - Guidance on isolating and addressing common problems
 
+8. Benchmarking and HPC Profiling
+=================================
+
+Anemoi Training offers tools and configurations to support benchmarking
+and High-Performance Computing (HPC) profiling, allowing users to
+optimize training performance. This includes:
+
+- Benchmarking configurations for evaluating training efficiency across
+  different hardware setups.
+- Profiling tools for monitoring resource utilization (CPU, GPU,
+  memory) and identifying performance bottlenecks.
+
 **************************
  Components and Structure
 **************************
diff --git a/docs/user-guide/benchmarking.rst b/docs/user-guide/benchmarking.rst
new file mode 100644
index 00000000..98ac2a9d
--- /dev/null
+++ b/docs/user-guide/benchmarking.rst
@@ -0,0 +1,746 @@
+##############
+ Benchmarking
+##############
+
+***************************************
+ High-level idea of the AnemoiProfiler
+***************************************
+
+Anemoi-training includes a benchmark profiler that provides summary
+logs/statistics about time, speed and hardware (memory, CPU/GPU usage)
+to profile training runs executed with anemoi-training. Apart from
+those reports, it can also generate a model summary and a CUDA memory
+snapshot.
+
+- **Speed Report** - Report with metrics associated with the throughput
+  at training and validation time.
+
+- **Time Report** - Report with metrics associated with the time it
+  takes to execute certain steps across the code.
+
+- **Memory Report** - Report with metrics associated with GPU and CPU
+  memory allocation, focusing on listing the operations that are most
+  memory-intensive.
+
+- **System/hardware Report** - Report with aggregated metrics in terms
+  of GPU utilisation & memory usage, CPU usage (system), average disk
+  usage and total execution time.
+
+- **Model Summary** - Table summary with information regarding the
+  layers and parameters of the model.
+
+- **Memory (GPU) Snapshot** - Memory snapshot that records the state of
+  allocated CUDA memory at any point in time, and optionally records
+  the history of allocation events that led up to that snapshot.
+
+.. figure:: ../images/profiler/anemoi_profiler_high_level.png
+   :alt: Schematic of the concept behind AnemoiProfiler
+   :align: center
+
+**************
+ How it works
+**************
+
+Conceptual Diagram
+==================
+
+As described in the high-level idea section, the ``AnemoiProfiler``
+includes a series of features and reports to help benchmark the model
+training performance. Anemoi-training uses PyTorch Lightning as its
+deep learning framework, and the AnemoiProfiler is designed to take
+advantage of, and build on top of, this functionality. AnemoiProfiler
+inherits from AnemoiTrainer and generates the different reports via 3
+main objects:
+
+- ``BenchmarkProfiler``
+- ``ProfilerProgressBar``
+- ``MemorySnapshotRecorder``
+
+Each of these objects is described in more detail in the sections
+below. With the exception of the ``MemorySnapshotRecorder``, all the
+above reports are defined as properties of the AnemoiProfiler. The
+memory snapshot is abstracted as an additional callback that can be
+switched on/off through the config.
+
+- Details about the definition of AnemoiProfiler can be found in
+  ``src/anemoi/training/commands/profiler.py``
+
+- Details about the definition of the different classes used by the
+  AnemoiProfiler can be found in:
+  ``src/anemoi/training/diagnostics/profilers.py``
+
+- Details about the definition of the memory snapshot recorder:
+  ``src/anemoi/training/diagnostics/callbacks/__init__.py``
+
+.. figure:: ../images/profiler/anemoi_profiler_architecture.png
+   :alt: Schematic of the AnemoiProfiler architecture
+   :align: center
+
+How to run it
+=============
+
+The profiler has been built on top of the existing anemoi-training
+workflow.
+For this we have defined a new class, ``AnemoiProfiler``, that inherits
+from ``AnemoiTrainer`` and just adds new features and methods relevant
+to the generation of the reports and the activation of the profiling
+mode. Similarly to how we run ``anemoi-training train`` to submit a new
+training job, we have added a new command to execute a profiler job, so
+we just need to run ``anemoi-training profile``.
+
+Following the same concept as the train command, the profiler command
+is also controlled via the definition of a config. For details about
+the config and the different fields required, please refer to the
+Config section. The full command to execute the profiler is:
+
+.. code:: bash
+
+   anemoi-training profile --config-name=config.yaml
+
+The profiler requires certain new packages to be installed, and hence
+has a specific section in the ``pyproject.toml``
+(``optional-dependencies.profile``). The first time you use it, make
+sure you have the dependencies installed by running:
+
+.. code:: bash
+
+   pip install -e anemoi-training[profile]
+
+Config
+======
+
+To control the execution of the anemoi benchmark profiler, we have to
+define the following fields in the evaluation.yaml file (inside the
+diagnostics folder) under the ``benchmark_profiler`` key.
+
+As mentioned, the benchmark profiler can generate different reports.
+For each report there is an entry in the config that decides whether
+to generate that report (if ``enabled: True`` the report is generated;
+if ``enabled: False`` it is skipped). Some reports have additional
+keys:
+
+- For the **time report**, we can also control the length/verbosity of
+  the report. If ``verbose: False``, the report will provide a more
+  concise set of actions, while if ``verbose: True`` it will include
+  the full list of profiled actions. See the Time Report section for
+  more information.
+
+- In the case of the **memory report**, aside from the summary
+  statistics, the MemoryProfiler can also provide additional insights
+  that include memory traces and a memory timeline; those can be
+  switched on by setting the ``extra_plots`` entry. The additional
+  config entries ``warmup``, ``steps`` and ``trace_rank0_only``
+  provide more control over the generation of the memory timeline and
+  traces, and are explained in the memory profiler section.
+
+- For the **(memory) snapshot**, we can also control the number of
+  recorded steps and warmup steps. See the MemorySnapshotRecorder
+  section for more information.
+
+.. figure:: ../images/profiler/anemoi_profiler_config.png
+   :alt: AnemoiProfiler Config Settings
+   :align: center
+
+**Note** - Anemoi Training also provides some functionality for quick
+troubleshooting using just the PytorchProfiler. To learn more about
+this, you can check the Troubleshooting section. That functionality is
+activated by setting ``profiler: True`` in the diagnostics config.
+**When using the benchmark profiler it's not necessary to set this
+flag**, since the benchmark profiler will automatically activate the
+PytorchProfiler when enabling the memory profiler. When running
+``anemoi-training profile`` it is therefore **recommended** to set
+``profiler: False`` in the diagnostics config to avoid any conflicts.
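+
+Putting the pieces above together, the ``benchmark_profiler`` block
+might look roughly like this. This is a sketch assembled from the
+examples elsewhere in this guide: only the ``memory`` and ``snapshot``
+entries are shown verbatim later in this page, while the
+``time``/``speed``/``system``/``model_summary`` keys are assumptions
+following the same ``enabled`` pattern.
+
+.. code:: yaml
+
+   benchmark_profiler:
+     time:
+       enabled: True
+       verbose: False # False = concise report, True = all profiled actions
+     memory:
+       enabled: True
+       steps: 6
+       warmup: 2
+       extra_plots: False
+       trace_rank0_only: True
+     speed:
+       enabled: True
+     system:
+       enabled: True
+     model_summary:
+       enabled: True
+     snapshot:
+       enabled: True
+       steps: 6
+       warmup: 2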
+
+BenchmarkProfiler
+=================
+
+The ``BenchmarkProfiler`` is the object in charge of generating the
+memory report, time report, model summary and system report. As the
+diagram indicates, this class inherits from the Pytorch Lightning base
+``Profiler`` class. Pytorch Lightning already provides built-in
+functionality that can be easily integrated with the Pytorch Lightning
+Trainer to profile the code. In particular, it provides access to some
+profilers
+(https://pytorch-lightning.readthedocs.io/en/1.5.10/advanced/profiler.html)
+that track performance across the training cycle in terms of execution
+time (``Simple`` and ``Advanced`` profilers) and in terms of CPU and
+GPU usage (``Pytorch Profiler``). We have designed the Benchmark
+Profiler to take advantage of that functionality and have extended it
+so it also provides a system report and a model summary. The diagram
+below illustrates this. As can be seen, the MemoryProfiler inherits
+from the PytorchProfiler and generates the Memory Report as its main
+output, and the TimeProfiler inherits from the SimpleProfiler and
+generates the Time Report as its output.
+
+.. figure:: ../images/profiler/anemoi_profiler_benchmark_profiler.png
+   :alt: Schematic of the BenchmarkProfiler architecture
+   :align: center
+
+In the diagram, orange boxes are outputs and dotted boxes refer to
+parent classes. ``get_memory_profiler_df``, ``get_time_profiler_df``,
+``get_model_summary`` and ``get_system_profiler_df`` are the main
+function interfaces of the BenchmarkProfiler.
+
+Time Report
+-----------
+
+For the time report of our Benchmark Profiler we have decided to use
+the ``Simple Profiler``. This profiler provides support to profile
+callbacks, DataHooks and ModelHooks in the training and validation
+loops. By default, the SimpleProfiler will record and output time
+estimates for any of the callbacks, DataHooks and ModelHooks that
+AnemoiTraining defines. To see this full report, one just needs to set
+``verbose: True`` in the config. However, since this can be quite
+extensive, there is an option to generate a shorter and more concise
+version of the time report with ``verbose: False``, so that it focuses
+on the callbacks and hooks coming from 3 main categories:
+
+- ``LightningDataModule (AnemoiDatasetDataModule)``
+- ``LightningModule (GraphForecaster)``
+- ``ParallelisationStrategy (DDPGroupStrategy)``
+
+Aside from these 3 categories, the report also includes:
+
+- the execution time for the training epoch (and training batch):
+
+  - ``run_training_epoch/run_training_batch`` → time it takes to
+    execute the 'training_step' per batch and per epoch (check
+    https://github.com/Lightning-AI/pytorch-lightning/blob/master/src/lightning/pytorch/loops/fit_loop.py
+    and
+    https://github.com/Lightning-AI/pytorch-lightning/blob/master/src/lightning/pytorch/loops/training_epoch_loop.py
+    for reference)
+
+- the time it takes the training dataloader and validation dataloader
+  to fetch one batch:
+
+  - ``[_TrainingEpochLoop].train_dataloader_next``
+  - ``[_EvaluationLoop].val_next``
+
+- For the callbacks, the ``SimpleProfiler`` provides time estimates of
+  all the different steps defined for each class, so for simplicity
+  the report just aggregates all those times into a single quantity
+  (see the ``AnemoiCheckpoint`` callback in the example below).
+
+Below you can find an example of the report the ``Time Profiler``
+issues after its execution.
+
+.. figure:: ../images/profiler/example_time_report.png
+   :alt: AnemoiProfiler Time Report
+   :align: center
+
+Note the above example corresponds to the time report generated when
+verbose is set to False in the config settings. If verbose is set to
+True, then there is no filtering applied to the actions profiled, and
+the time report will include many more entries.
+
+System Report
+-------------
+
+This report provides a table with summary metrics in terms of GPU
+utilisation & memory usage, CPU usage (system), average disk usage and
+total execution time. For now the System profiler relies on the
+metrics tracked by MlFlow, which is the tool we use to track our ML
+experiments. If you run the profiler without MlFlow, it will still be
+possible to generate all the other reports, but the code will indicate
+that the system report can't be generated.
+
+When running anemoi-training with MlFlow activated, the tool also
+tracks a set of system metrics and logs them into the UI. MlFlow does
+this through its ``SystemMetricsMonitor``. For more information you
+can check the docs:
+https://mlflow.org/docs/latest/system-metrics/index.html
+
+In this report we simply take the average of those metrics; in the
+case of the metrics associated with the GPUs, we also include metrics
+per GPU device.
+
+Below you can find an example of the ``System Report``.
+
+.. figure:: ../images/profiler/example_system_report.png
+   :alt: AnemoiProfiler System Report
+   :align: center
+   :width: 300px
+
+Memory Profiler
+---------------
+
+As mentioned above, PTL provides functionality to profile the code. In
+particular, one can use the PyTorch profiler to measure the time and
+memory consumption of the model’s operators
+(https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html).
+The report includes GPU/CPU utilisation, memory usage, and execution
+time for different operations within the model. So far we have
+configured it so that the report includes the top 20 operators with
+the largest GPU utilisation (note this can be adapted and we are keen
+to get feedback).
+
+Below you can find an example of the report generated by the ``Memory
+Profiler``:
+
+.. figure:: ../images/profiler/example_memory_report.png
+   :alt: AnemoiProfiler Memory Report
+   :align: center
+
+Note the difference between self CPU time and CPU time: operators can
+call other operators, and self CPU time excludes time spent in child
+operator calls, while total CPU time includes it. Similarly, the
+profiler can also show the amount of memory (used by the model’s
+tensors) that was allocated (or released - negative deallocation)
+during the execution of the model’s operators. In the example, ‘self’
+memory corresponds to the memory allocated (released) by the operator,
+excluding the child calls to the other operators.
+
+To use this functionality, one just needs to specify the following
+entries in the config (Benchmark Profiler section):
+
+.. code:: yaml
+
+   memory:
+     enabled: True
+     steps: 6
+     warmup: 2
+     extra_plots: False
+     trace_rank0_only: True
+
+The ``enabled`` flag triggers the generation of the report shown
+above. Tracing all of the execution can be slow and result in very
+large trace files. To avoid this, we have some optional arguments that
+are passed to the profiler scheduler:
+
+- warming up (``warmup=2`` steps): during this phase the profiler
+  starts tracing, but the results are discarded. This phase is used to
+  discard the samples obtained at the beginning of the trace, since
+  they are usually skewed by extra overhead.
+
+- active tracing (``active=6`` steps, set via the ``steps`` entry):
+  during this phase the profiler traces and records data.
+
+**Note** if you use ``limit_batches`` in the dataloader, the number of
+batches selected should be greater than the sum of warmup and steps.
+If not, the profiler will not be able to generate the report.
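+
+For intuition, the ``warmup``/``steps`` pair maps onto PyTorch's
+profiler schedule roughly as follows. This is a standalone sketch
+using the public ``torch.profiler`` API on a CUDA machine, not
+anemoi-training code; the matrix multiply is a stand-in for a real
+training step.
+
+.. code:: python
+
+   import torch
+   from torch.profiler import ProfilerActivity, profile, schedule
+
+   # 2 warmup steps are traced but discarded, then 6 steps are recorded
+   prof_schedule = schedule(wait=0, warmup=2, active=6)
+
+   with profile(
+       activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
+       schedule=prof_schedule,
+       profile_memory=True,
+   ) as prof:
+       for _ in range(8):  # must be >= warmup + active
+           x = torch.randn(1024, 1024, device="cuda")
+           y = (x @ x).sum()  # stand-in for one training iteration
+           prof.step()  # advance the profiler schedule
+
+   print(prof.key_averages().table(sort_by="self_cuda_memory_usage", row_limit=20))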
+
+It's also possible to generate additional products/reports with the
+memory profiler: the memory timeline and the memory traces. Those take
+more time to generate, and hence it is possible to choose whether we
+want them (``extra_plots: True``) or not (``extra_plots: False``). For
+details about those exact plots please check the section below about
+**Memory Profiler Extras**. If using multiple GPUs, the output of the
+memory traces will be significantly larger. Since usually there are
+certain operations that only happen on rank 0, we may just be
+interested in the outputs coming from this device. It's then possible
+to generate traces and results just from rank 0 by setting
+``trace_rank0_only`` to True. Note if we just have one device this
+flag doesn't make any difference; it's only relevant if we have more
+than one.
+
+**Note Memory Profiler - Patch**
+
+We identified a bug in the PytorchProfiler and are awaiting the fix
+(see the upstream PR) to be included as part of the next Pytorch
+release (so far it is only included in the nightly version). To avoid
+hitting the error, the current AnemoiProfiler introduces a patch (see
+the ``PatchedProfile`` class in the ``profilers.py`` script). This
+patch will be removed from the codebase as soon as there is a new
+official Pytorch release that includes the fix.
+
+**Memory Profiler Extras - Memory Traces & Memory Timeline**
+
+**Memory Timeline**
+
+The PytorchProfiler automatically generates categories based on the
+graph of tensor operations recorded during profiling; it's possible to
+visualise these categories and their evolution across the execution
+using the ``export_memory_timeline`` method. You can find an example
+of the memory timeline plot below (this example comes from
+https://pytorch.org/blog/understanding-gpu-memory-1/). The exported
+timeline plot is in HTML format.
+
+.. figure:: ../images/profiler/example_memory_timeline.png
+   :alt: Example of PytorchProfiler's Memory Timeline
+   :align: center
+
+**Memory Traces**
+
+The PytorchProfiler enables recording of stack traces associated with
+memory allocations, and results can be output as a .json trace file.
+The PyTorch Profiler leverages the ``Kineto`` library to collect GPU
+traces. Kineto is the subsystem within the profiler that interfaces
+with CUPTI. GPU kernels execute asynchronously, and GPU-side support
+is needed to create the trace; NVIDIA provides this visibility via the
+CUPTI library.
+
+The ``Kineto`` project enables:
+
+- Performance observability and diagnostics across common ML
+  bottleneck components.
+- Actionable recommendations for common issues.
+- Integration of external system-level profiling tools.
+- Integration with popular visualization platforms and analysis
+  pipelines.
+
+Since these trace files are complex and challenging to interpret, it's
+very useful to have supporting packages to analyse them. Holistic
+Trace Analysis (HTA) is an open-source performance analysis and
+visualization Python library for PyTorch users. It provides the
+following features:
+
+- **Temporal Breakdown** - Breakdown of time taken by the GPUs in
+  terms of time spent in computation, communication, memory events,
+  and idle time across all ranks.
+
+- **Kernel Breakdown** - Finds kernels with the longest duration on
+  each rank.
+
+- **Kernel Duration Distribution** - Distribution of average time
+  taken by the longest kernels across different ranks.
+
+- **Idle Time Breakdown** - Breakdown of GPU idle time into waiting
+  for the host, waiting for another kernel, or attribution to an
+  unknown cause.
+
+- **Communication Computation Overlap** - Calculate the percentage of
+  time when communication overlaps computation.
+
+- **Frequent CUDA Kernel Patterns** - Find the CUDA kernels most
+  frequently launched by any given PyTorch or user-defined operator.
+
+- **CUDA Kernel Launch Statistics** - Distributions of GPU kernels
+  with very small duration, large duration, and excessive launch time.
+
+- **Augmented Counters (Queue length, Memory bandwidth)** - Augmented
+  trace files which provide insights into memory bandwidth utilized
+  and the number of outstanding operations on each CUDA stream.
+
+- **Trace Comparison** - A trace comparison tool to identify and
+  visualize the differences between traces.
+
+- **CUPTI Counter Analysis** - An experimental API to get GPU
+  performance counters. By attributing performance measurements from
+  kernels to PyTorch operators, roofline analysis can be performed and
+  kernels can be optimized.
+
+To load the traces and explore them using HTA, one can set up a
+Jupyter notebook and run:
+
+.. code:: python
+
+   from hta.trace_analysis import TraceAnalysis
+   from pathlib import Path
+   from hydra import initialize, compose
+   from omegaconf import OmegaConf
+
+   base_path = Path.cwd().parent
+   with initialize(version_base=None, config_path="./"):
+       cfg = compose(config_name="config.yaml")
+       OmegaConf.resolve(cfg)
+
+   # Run anemoi-training profile to generate the traces and get the run_id
+   run_id = "b0cc5f6fa6c0476aa1264ad7aacafb4d/"
+   tracepath = cfg.hardware.paths.profiler + run_id
+   analyzer = TraceAnalysis(trace_dir=tracepath)
+
+   # Temporal Breakdown
+   time_df = analyzer.get_temporal_breakdown()
+
+The function returns a dataframe containing the temporal breakdown for
+each rank. See the figure below.
+
+.. figure:: ../images/profiler/temporal_breakdown.png
+   :alt: Temporal Breakdown HTA Example
+   :align: center
+
+The idle time breakdown can be generated as follows:
+
+.. code:: python
+
+   # Idle Time Breakdown
+   idle_time_df_r0 = analyzer.get_idle_time_breakdown()
+
+The function returns a dataframe containing the idle time breakdown
+for each rank. See the figure below.
+
+.. figure:: ../images/profiler/idle_time_breakdown.png
+   :alt: Idle Time Breakdown HTA Example
+   :align: center
+
+Additionally, we can look at the kernel breakdown feature, which
+breaks down the time spent in each kernel type, i.e. communication
+(COMM), computation (COMP), and memory (MEM), across all ranks, and
+presents the proportion of time spent in each category as a pie chart.
+
+.. code:: python
+
+   # Kernel Breakdown
+   # NCCL changed their kernel naming convention so HTA v2.0 doesn't recognise communication kernels
+   # This can be fixed by editing one line of hta/utils/util.py, see https://github.com/facebookresearch/HolisticTraceAnalysis/pull/123
+
+   # see https://github.com/facebookresearch/HolisticTraceAnalysis/blob/main/examples/kernel_breakdown_demo.ipynb
+   kernel_type_metrics_df, kernel_metrics_df = analyzer.get_gpu_kernel_breakdown(
+       num_kernels=5, include_memory_kernels=True, visualize=True
+   )
+
+The first dataframe returned by the function contains the raw values
+used to generate the pie chart. The second dataframe returned by
+``get_gpu_kernel_breakdown`` contains duration summary statistics for
+each kernel. In particular, this includes the count, min, max,
+average, standard deviation, sum and kernel type for each kernel on
+each rank.
+
+.. figure:: ../images/profiler/kernel_breakdown_dfs.png
+   :alt: Kernel Breakdown HTA - Dataframes Example
+   :align: center
+
+Using this data, HTA creates many visualizations to identify
+performance bottlenecks:
+
+- **Pie charts** of the top kernels for each kernel type for each
+  rank.
+- **Bar graphs** of the average duration across all ranks for each of
+  the top kernels and for each kernel type.
+
+.. figure:: ../images/profiler/kernel_breakdown_plots.png
+   :alt: Kernel Breakdown HTA - Plots Example
+   :align: center
+
+For more examples using HTA you can check
+https://github.com/facebookresearch/HolisticTraceAnalysis/tree/main/examples
+and the package docs https://hta.readthedocs.io/en/latest/.
+Additionally, we recommend this blog post from Pytorch:
+https://pytorch.org/blog/trace-analysis-for-masses/
+
+Model Summary
+-------------
+
+While the ``ModelSummary`` does not fall within the category of
+reports associated with computational performance, there is usually a
+connection between the size of the model and its demand for
+computational resources. The ``ModelSummary`` provides a summary table
+breaking down the model architecture and the number of trainable
+parameters per layer. The functionality used to create this summary
+relies on https://github.com/TylerYep/torchinfo, and for the exact
+details one can check the function ``get_model_summary`` defined as
+part of the ``BenchmarkProfiler`` class. Below you can find an example
+of the Model Summary produced. Note that, due to the size of the
+summary, the screenshot below is truncated.
+
+.. figure:: ../images/profiler/example_model_summary.png
+   :alt: Example of AnemoiProfiler's Model Summary - Part I
+   :align: center
+
+.. figure:: ../images/profiler/example_model_summary_2.png
+   :alt: Example of AnemoiProfiler's Model Summary - Part II
+   :align: center
+
+ProfilerProgressBar
+===================
+
+**Speed Report**
+
+While time and speed are related, we wanted a separate ``Speed
+Report`` that focuses just on the throughput metrics of the training
+and validation loops. To get those metrics we take advantage of the
+iterations per second reported by the ``TQDMProgressBar``, which can
+be easily integrated when running a model with PTL. As indicated in
+the diagram below, the ProfilerProgressBar inherits from
+``TQDMProgressBar`` and generates the Speed Report as its main output.
+
+The progress bar measures the iterations per second (``it/s``) by
+computing the elapsed time at the start and end of each training and
+validation iteration** (where iteration in this case refers to the
+number of batches in each epoch). The report provides an aggregated
+throughput by taking the average across all epochs. Since this metric
+can be sensitive to the number of samples per batch, the report
+includes a ``throughput_per_sample`` where we simply normalise the
+aggregated metrics taking into account the batch size used for
+training and validation. In the case of the dataloader(s) throughput,
+this refers to the performance of the dataloader in terms of fetching
+and collating a batch; again, since this metric can be influenced by
+the selected batch size, we also provide a normalised dataloader
+throughput.
The report provides an aggregated throughput by taking the +average across all epochs. Since this metric can be sensitive to the +number of samples per batch, the report includes a throughput_per_sample +where we simply just normalised the aggregated metrics taking into +account the batch size used for training and validation. Ib the case of +the dataloader(s) throughput this refers to the performance of +dataloader in terms of fetching and collating a batch, and again since +this metric can be influence by the selected batch size, we also +provided a normalised dataloader throughput. + +.. figure:: ../images/profiler/anemoi_profiler_speedreport_diagram.png + :alt: AnemoiProfiler's Speed Report Architecture + :align: center + :width: 200px + +Note, this is not just the ``training_step`` as we had recorded in the +'Time Profiler Report' but it also includes all the callbacks/hooks that +are executed during each training/validation iteration. Since most of +our callbacks are related to sanity and validation plots carried out +during the validation, we should expect lower throughputs compared to +training + +Below you can find an example of the report generated by the ``Speed +Profiler``: + +.. figure:: ../images/profiler/anemoi_profiler_speed_report.png + :alt: Example of AnemoiProfiler's Speed Report + :align: center + :width: 300px + +** CUDA and CPU total time as just time metrics (in seconds) computed by +the Memory Profiler. For now we have decided to ingrate and display them +as part of the Speed Report, but we can revisit that decision based on +user feedback + +MemorySnapshotRecorder +====================== + +With the latest pytorch versions (Pytorch equal or higher than 2.1), the +library introduces new features to analyse the GPU memory footprint. +https://pytorch.org/docs/stable/torch_cuda_memory.html#generating-a-snapshot +. The AnemoiProfiler integrates these new features through a custom +callback ``MemorySnapshotRecorder``. The memory snapshot generated is a +pickle file that records the state of allocated CUDA memory at any point +in time, and optionally record the history of allocation events that led +up to that snapshot. Captured memory snapshots will show memory events +including allocations, frees and OOMs, along with their stack traces. +The generated snapshots can then be drag and dropped onto the +interactive viewer hosted at pytorch.org/memory_viz which can be used to +explore the snapshot. To activate this callback, one just need to +specify the following entries in the config (Benchmark Profiler +section): + +.. code:: yaml + + snapshot: + enabled: True + steps: 6 + warmup: 2 + +If we don't want to generate a snapshot we simply set the ``enabled`` +flag to False. If we enable the snapshot recorder, then we need to +define the number of steps we want to record. Note a bigger number of +steps will generate a heavier file that then might take longer to render +in the website (pytorch.org/memory_viz). + +The Callback so far is defined to start tracking the CUDA memory at the +start of the training batch, when the global step matches the number of +warmup steps and end at the end of the training batch when the global +step matches the number of total steps (steps+warmup) defined. Note if +warmup is null then no warmup steps are considered, and the recording +will star as soon as the training starts. + +.. 
+.. figure:: ../images/profiler/memory_snapshot_diagram.png
+   :alt: AnemoiProfiler's MemorySnapshotRecorder Architecture
+   :align: center
+   :width: 200px
+
+The example below shows what a ``memory snapshot`` recorded over 6
+steps looks like:
+
+.. figure:: ../images/profiler/memory_snapshot_output.png
+   :alt: Example of AnemoiProfiler's Memory Snapshot
+   :align: center
+
+********************
+ Mlflow Integration
+********************
+
+If you use MLflow to track your run, all the reports generated by the
+profiler are also logged to MLflow. For now, the speed, time, memory
+and system reports are logged both as JSON and CSV files. We hope to
+receive feedback about this, so that in the future we can choose
+between the two formats. The additional outputs generated by the memory
+profiler (memory timelines and traces) are not tracked in MLflow due to
+the large size of those files.
+
+.. figure:: ../images/profiler/anemoi_profiler_mlflow_integration.png
+   :alt: AnemoiProfiler - Mlflow integration
+   :align: center
+
+One advantage of logging the reports as JSON is that those files can be
+logged as ``table artifacts``, which can then be compared across
+different runs through the Evaluation tab. Below you can see an example
+comparing the system report metrics and speed metrics for two different
+runs:
+
+.. figure:: ../images/profiler/anemoi_profiler_mlflow_integration_2.png
+   :alt: AnemoiProfiler - Example Table Evaluation
+   :align: center
+
+Speed report - train/validation rates
+=====================================
+
+When using MLflow, there are two additional metrics that can be
+explored:
+
+- ``training_rate`` - the iterations per second (it/s) recorded by the
+  `ProfilerProgressBar` across the training cycle. While the
+  SpeedReport provides the averaged throughput
+  `training_avg_throughput`, the rate shows the evolution of the
+  throughput over time.
+
+- ``validation_rate`` - the iterations per second (it/s) recorded by
+  the `ProfilerProgressBar` across the validation cycle. While the
+  SpeedReport provides the averaged throughput
+  `validation_avg_throughput`, the rate shows the evolution of the
+  throughput over time.
+
+Note: to get these metrics, the ``SpeedProfiler`` needs to be enabled.
+Below you can find an example of what the ``training_rate`` and
+``validation_rate`` look like for two different runs.
+
+.. figure:: ../images/profiler/anemoi_profiler_training_rates.png
+   :alt: Example of AnemoiProfiler's Training Rates
+   :align: center
+
+.. figure:: ../images/profiler/anemoi_profiler_validation_rates.png
+   :alt: Example of AnemoiProfiler's Validation Rates
+   :align: center
+
+****************************
+ Limitations & Improvements
+****************************
+
+Limitations
+===========
+
+- General challenge for AI code benchmarking results → noise coming
+  from hardware and the stochastic behaviour of AI workloads
+
+- ``SpeedReport`` → robustness of the metrics (val/train rates and
+  throughput)
+
+- ``TimeProfiler`` → ability to profile just part of the code (so far
+  the SimpleProfiler only records the pre-defined, hardcoded actions
+  listed in PROFILER_ACTIONS in the codebase, and, as mentioned above,
+  those actions need to be a DataHook, ModelHook or Callback)
+
+- ``TimeProfiler`` → limitations when timing asynchronous parts of the
+  code
+
+- ``MemoryProfiler`` → the report requires a good understanding of the
+  PyTorch profiler's model operators
+
+- ``SpeedReport`` → train/val rates categorisation
+
+Improvements
+============
+
+- https://pytorch.org/tutorials/recipes/recipes/benchmark.html
+
+- Decorator-style partial profiling -
+  https://github.com/pythonprofilers/memory_profiler or
+  https://github.com/pyutils/line_profiler
+
+- Defining a decorator or wrapper for the ``TimeProfiler`` could be
+  helpful to provide more control and access to time profiling of other
+  parts of the codebase
+
+- Asynchronous code profiling -> https://github.com/sumerc/yappi
+
+- Performance benchmarking and integration with CI/CD - possibility to
+  run the profiler for different code releases as part of GitHub
+  Actions
+
+- Energy reports
+
+- Better compatibility with other hardware (AMD GPUs, IPUs, etc.) - the
+  system metrics monitor might not work out of the box with hardware
+  other than Nvidia's, since the library it uses to record the GPU
+  metrics is pynvml. We could extend the functionality to profile other
+  hardware such as AMD GPUs or Graphcore IPUs
+
+- Support other components of Anemoi like ``anemoi-inference``
diff --git a/docs/user-guide/configuring.rst b/docs/user-guide/configuring.rst
index 35efec3f..307cebf0 100644
--- a/docs/user-guide/configuring.rst
+++ b/docs/user-guide/configuring.rst
@@ -21,7 +21,7 @@ settings at the top as follows:
   defaults:
   - data: zarr
   - dataloader: native_grid
-  - diagnostics: eval_rollout
+  - diagnostics: evaluation
   - hardware: example
   - graph: multi_scale
   - model: gnn
@@ -100,7 +100,7 @@ match the dataset you provide.
   defaults:
   - data: zarr
   - dataloader: native_grid
-  - diagnostics: eval_rollout
+  - diagnostics: evaluation
   - hardware: example
   - graph: multi_scale
   - model: transformer # Change from default group
diff --git a/docs/user-guide/debugging.rst b/docs/user-guide/debugging.rst
index 38856620..b293e3fd 100644
--- a/docs/user-guide/debugging.rst
+++ b/docs/user-guide/debugging.rst
@@ -142,7 +142,14 @@ Turn off plotting callbacks to isolate non-visualization related issues:
    diagnostics:
      plot:
-       enabled: false
+       callbacks: []
+
+Or set the plot config to ``none`` (in ``diagnostics.evaluation``):
+
+.. code:: yaml
+
+   defaults:
+     plot: none
 **********************************
  Debugging C10 Distributed Errors
diff --git a/docs/user-guide/tracking.rst b/docs/user-guide/tracking.rst
index cab5e851..f97182d7 100644
--- a/docs/user-guide/tracking.rst
+++ b/docs/user-guide/tracking.rst
@@ -33,7 +33,7 @@ the same experiment.
 Within the MLflow experiments tab, it is possible to define different
 namespaces. To create a new namespace, the user just needs to pass an
 'experiment_name'
-(``config.diagnostics.eval_rollout.log.mlflow.experiment_name``) to the
+(``config.diagnostics.evaluation.log.mlflow.experiment_name``) to the
 mlflow logger.
 **Parent-Child Runs**
diff --git a/docs/user-guide/training.rst b/docs/user-guide/training.rst
index 695720b4..5be08222 100644
--- a/docs/user-guide/training.rst
+++ b/docs/user-guide/training.rst
@@ -172,8 +172,8 @@ by setting ``config.data.normaliser``, such that:
 It is possible to change the weighting given to each of the variables
 in the loss function by changing
-``config.training.loss_scaling.pl.`` and
-``config.training.loss_scaling.sfc.``.
+``config.training.variable_loss_scaling.pl.``
+and ``config.training.variable_loss_scaling.sfc.``.
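+For example (a sketch with illustrative values; ``q: 0.6`` matches the
+default config, while the ``sfc`` entry is hypothetical):
+
+.. code:: yaml
+
+   training:
+     variable_loss_scaling:
+       default: 1
+       pl:
+         q: 0.6
+       sfc:
+         2t: 3.5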
It is also possible to change the scaling given to the pressure levels using ``config.training.pressure_level_scaler``. For almost all diff --git a/pyproject.toml b/pyproject.toml index 9ffa2db1..f3e730d3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,14 +1,12 @@ -#!/usr/bin/env python -# (C) Copyright 2024 ECMWF. +# (C) Copyright 2024 Anemoi contributors. # # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. -# https://packaging.python.org/en/latest/guides/writing-pyproject-toml/ - [build-system] requires = [ "setuptools>=60", "setuptools-scm>=8" ] @@ -43,14 +41,14 @@ dynamic = [ "version" ] dependencies = [ "anemoi-datasets>=0.4", - "anemoi-graphs", + "anemoi-graphs>=0.4", "anemoi-models>=0.3", - "anemoi-utils[provenance]>=0.3.10", + "anemoi-utils[provenance]>=0.4.4", "einops>=0.6.1", "hydra-core>=1.3", "matplotlib>=3.7.1", "mlflow>=2.11.1", - "numpy<2", # Pinned until we can confirm it works with anemoi graphs + "numpy<2", # Pinned until we can confirm it works with anemoi graphs "pynvml>=11.5", "pyshtools>=4.10.4", "pytorch-lightning>=2.1", @@ -76,6 +74,13 @@ optional-dependencies.docs = [ "sphinx-argparse", "sphinx-rtd-theme", ] +optional-dependencies.profile = [ + "holistictraceanalysis>=0.2", + "pandas>=1.3.2", + "rich>=13.6", + "tabulate>=0.9", +] + optional-dependencies.tests = [ "hypothesis", "pytest", "pytest-mock" ] urls.Changelog = "https://github.com/ecmwf/anemoi-training/CHANGELOG.md" @@ -85,8 +90,9 @@ urls.Issues = "https://github.com/ecmwf/anemoi-training/issues" urls.Repository = "https://github.com/ecmwf/anemoi-training/" # command for interactive DDP (not supposed to be used directly) # the dot is intentional, so it doesn't trigger autocomplete +# Files need to be named profiler due to A005 Module `profile` is shadowing a Python builtin module +scripts.".anemoi-training-profile" = "anemoi.training.commands.profiler:main" scripts.".anemoi-training-train" = "anemoi.training.commands.train:main" - # Add subcommand in the `commands` directory scripts.anemoi-training = "anemoi.training.__main__:main" diff --git a/src/anemoi/training/__init__.py b/src/anemoi/training/__init__.py index 9733be26..af4a3aea 100644 --- a/src/anemoi/training/__init__.py +++ b/src/anemoi/training/__init__.py @@ -1,6 +1,8 @@ -# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# (C) Copyright 2024 Anemoi contributors. +# # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. 
@@ -9,7 +11,7 @@ try: # NOTE: the `_version.py` file must not be present in the git repository # as it is generated by setuptools at install time - from ._version import __version__ # type: ignore + from ._version import __version__ except ImportError: # pragma: no cover # Local copy or not installed with setuptools __version__ = "999" diff --git a/src/anemoi/training/__main__.py b/src/anemoi/training/__main__.py index f571e1b1..8de7a257 100644 --- a/src/anemoi/training/__main__.py +++ b/src/anemoi/training/__main__.py @@ -1,11 +1,11 @@ -# (C) Copyright 2024 ECMWF. +# (C) Copyright 2024 Anemoi contributors. # # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. + # # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. -# from anemoi.utils.cli import cli_main from anemoi.utils.cli import make_parser diff --git a/src/anemoi/training/commands/__init__.py b/src/anemoi/training/commands/__init__.py index 17413995..c8fb7a99 100644 --- a/src/anemoi/training/commands/__init__.py +++ b/src/anemoi/training/commands/__init__.py @@ -1,11 +1,11 @@ -# (C) Copyright 2024 ECMWF. +# (C) Copyright 2024 Anemoi contributors. # # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. + # # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. -# from pathlib import Path diff --git a/src/anemoi/training/commands/profiler.py b/src/anemoi/training/commands/profiler.py new file mode 100644 index 00000000..5ec9dfa0 --- /dev/null +++ b/src/anemoi/training/commands/profiler.py @@ -0,0 +1,47 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +from __future__ import annotations + +import logging +import os +import sys + +from anemoi.training.commands.train import TrainBase + +LOGGER = logging.getLogger(__name__) + + +class Profile(TrainBase): + """Commands to profile Anemoi models.""" + + accept_unknown_args = True + command = "profile" + + def run(self, args: list[str], unknown_args: list[str] | None = None) -> None: + # This will be picked up by the logger + self.prepare_sysargv(args, unknown_args) + + LOGGER.info("Running anemoi profile command with overrides: %s", sys.argv[1:]) + main() + + +def main() -> None: + # Use the environment variable to check if main is being called from the subcommand, not from the ddp entrypoint + if not os.environ.get("ANEMOI_TRAINING_CMD"): + error = "This entrypoint should not be called directly. Use `anemoi-training profile` instead."
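+        # ANEMOI_TRAINING_CMD is set by TrainBase.prepare_sysargv() when the subcommand runs, so this guard rejects direct calls to the entrypoint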
+ raise RuntimeError(error) + + from anemoi.training.train.profiler import main as anemoi_profile + + anemoi_profile() + + +command = Profile diff --git a/src/anemoi/training/commands/train.py b/src/anemoi/training/commands/train.py index 44ce186f..0cf4b2f6 100644 --- a/src/anemoi/training/commands/train.py +++ b/src/anemoi/training/commands/train.py @@ -13,6 +13,8 @@ import logging import os import sys +from abc import ABC +from abc import abstractmethod from pathlib import Path from typing import TYPE_CHECKING @@ -24,30 +26,13 @@ LOGGER = logging.getLogger(__name__) -class Train(Command): - """Commands to train Anemoi models.""" - +class TrainBase(Command, ABC): accept_unknown_args = True @staticmethod def add_arguments(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: return parser - def run(self, args: argparse.Namespace, unknown_args: list[str] | None = None) -> None: - # This will be picked up by the logger - os.environ["ANEMOI_TRAINING_CMD"] = f"{sys.argv[0]} {args.command}" - # Merge the known subcommands with a non-whitespace character for hydra - new_sysargv = self._merge_sysargv(args) - - # Add the unknown arguments (belonging to hydra) to sys.argv - if unknown_args is not None: - sys.argv = [new_sysargv, *unknown_args] - else: - sys.argv = [new_sysargv] - - LOGGER.info("Running anemoi training command with overrides: %s", sys.argv[1:]) - main() - def _merge_sysargv(self, args: argparse.Namespace) -> str: """Merge the sys.argv with the known subcommands to pass to hydra. @@ -74,6 +59,31 @@ def _merge_sysargv(self, args: argparse.Namespace) -> str: modified_sysargv += f"-{args.subcommand}" return str(modified_sysargv) + def prepare_sysargv(self, args: argparse.Namespace, unknown_args: list[str] | None = None) -> None: + os.environ["ANEMOI_TRAINING_CMD"] = f"{sys.argv[0]} {args.command}" + # Merge the known subcommands with a non-whitespace character for hydra + new_sysargv = self._merge_sysargv(args) + + # Add the unknown arguments (belonging to hydra) to sys.argv + if unknown_args is not None: + sys.argv = [new_sysargv, *unknown_args] + else: + sys.argv = [new_sysargv] + + @abstractmethod + def run(self, args: argparse.Namespace, unknown_args: list[str] | None = None) -> None: ... + + +class Train(TrainBase): + """Commands to train Anemoi models.""" + + def run(self, args: argparse.Namespace, unknown_args: list[str] | None = None) -> None: + # This will be picked up by the logger + self.prepare_sysargv(args, unknown_args) + + LOGGER.info("Running anemoi training command with overrides: %s", sys.argv[1:]) + main() + def main() -> None: # Use the environment variable to check if main is being called from the subcommand, not from the ddp entrypoint diff --git a/src/anemoi/training/config/__init__.py b/src/anemoi/training/config/__init__.py index 282d6a69..c167afa2 100644 --- a/src/anemoi/training/config/__init__.py +++ b/src/anemoi/training/config/__init__.py @@ -1,6 +1,8 @@ -# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# (C) Copyright 2024 Anemoi contributors. +# # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. 
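For reference, a sketch of how the new profiler entry point is expected to be
invoked once the optional profiling dependencies are installed (the Hydra
override shown is illustrative):

.. code:: bash

   pip install "anemoi-training[profile]"
   anemoi-training profile diagnostics.benchmark_profiler.snapshot.enabled=True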
diff --git a/src/anemoi/training/config/config.yaml b/src/anemoi/training/config/config.yaml index 2045da93..a379acfd 100644 --- a/src/anemoi/training/config/config.yaml +++ b/src/anemoi/training/config/config.yaml @@ -1,7 +1,7 @@ defaults: - data: zarr - dataloader: native_grid -- diagnostics: eval_rollout +- diagnostics: evaluation - hardware: example - graph: multi_scale - model: gnn diff --git a/src/anemoi/training/config/data/zarr.yaml b/src/anemoi/training/config/data/zarr.yaml index 1657861f..3b9a4537 100644 --- a/src/anemoi/training/config/data/zarr.yaml +++ b/src/anemoi/training/config/data/zarr.yaml @@ -30,6 +30,21 @@ remapped: normalizer: default: "mean-std" + + # Remap cp statistics to those of tp when using FractionBounding. This ensures + # that cp, as a fraction of tp, remains consistent with tp's scale and statistics. + # NOTE: This remap should only be applied if FractionBounding is enabled for cp. + # remap: + # cp: tp + + # Standardization applied to tp and cp variables. Ensure that if cp is bounded + # as a fraction of tp, both variables are normalized using these shared statistics. + # "Std" normalization is preferred here over "mean-std" to avoid shifting of the + # zero value in the normalized space. + std: + - "tp" + # - "cp" + min-max: max: - "sdor" diff --git a/src/anemoi/training/config/dataloader/native_grid.yaml b/src/anemoi/training/config/dataloader/native_grid.yaml index e6d50801..d7aa4f6d 100644 --- a/src/anemoi/training/config/dataloader/native_grid.yaml +++ b/src/anemoi/training/config/dataloader/native_grid.yaml @@ -45,6 +45,8 @@ training: frequency: ${data.frequency} drop: [] +validation_rollout: 1 # number of rollouts to use for validation, must be equal or greater than rollout expected by callbacks + validation: dataset: ${dataloader.dataset} start: 2021 diff --git a/src/anemoi/training/config/debug.yaml b/src/anemoi/training/config/debug.yaml index 5be3e9f4..a6143bb6 100644 --- a/src/anemoi/training/config/debug.yaml +++ b/src/anemoi/training/config/debug.yaml @@ -1,7 +1,7 @@ defaults: - data: zarr - dataloader: native_grid -- diagnostics: eval_rollout +- diagnostics: evaluation - hardware: example - graph: multi_scale - model: gnn @@ -18,7 +18,7 @@ defaults: diagnostics: plot: - enabled: False + callbacks: [] hardware: files: graph: ??? diff --git a/src/anemoi/training/config/diagnostics/benchmark_profiler/detailed.yaml b/src/anemoi/training/config/diagnostics/benchmark_profiler/detailed.yaml new file mode 100644 index 00000000..486c1851 --- /dev/null +++ b/src/anemoi/training/config/diagnostics/benchmark_profiler/detailed.yaml @@ -0,0 +1,20 @@ +# Use anemoi-profile to profile the training process +memory: + enabled: True + steps: 5 # wait warmup steps and then do steps (too many steps would lead to a big file) + warmup: 2 + extra_plots: False + trace_rank0_only: False #set to true and it will profile rank 0 only. 
Reads SLURM_PROC_ID so won't work when not running via Slurm +time: + enabled: True + verbose: False #If true, output every action the profiler captures, otherwise output a subset defined in PROFILER_ACTIONS at the top of aifs/diagnostics/profiler.py +speed: + enabled: True +system: + enabled: True +model_summary: + enabled: True +snapshot: + enabled: True + steps: 4 # wait warmup steps and then do steps + warmup: 0 diff --git a/src/anemoi/training/config/diagnostics/benchmark_profiler/simple.yaml b/src/anemoi/training/config/diagnostics/benchmark_profiler/simple.yaml new file mode 100644 index 00000000..34c8023d --- /dev/null +++ b/src/anemoi/training/config/diagnostics/benchmark_profiler/simple.yaml @@ -0,0 +1,20 @@ +# Use anemoi-profile to profile the training process +memory: + enabled: False + steps: 5 # wait warmup steps and then do steps (too many steps would lead to a big file) + warmup: 2 + extra_plots: False + trace_rank0_only: False #set to true and it will profile rank 0 only. Reads SLURM_PROC_ID so won't work when not running via Slurm +time: + enabled: True + verbose: False #If true, output every action the profiler captures, otherwise output a subset defined in PROFILER_ACTIONS at the top of aifs/diagnostics/profiler.py +speed: + enabled: True +system: + enabled: False +model_summary: + enabled: False +snapshot: + enabled: False + steps: 4 # wait warmup steps and then do steps + warmup: 0 diff --git a/src/anemoi/training/config/diagnostics/callbacks/pretraining.yaml b/src/anemoi/training/config/diagnostics/callbacks/pretraining.yaml new file mode 100644 index 00000000..1eb35f69 --- /dev/null +++ b/src/anemoi/training/config/diagnostics/callbacks/pretraining.yaml @@ -0,0 +1 @@ +# Add callbacks here diff --git a/src/anemoi/training/config/diagnostics/callbacks/rollout_eval.yaml b/src/anemoi/training/config/diagnostics/callbacks/rollout_eval.yaml new file mode 100644 index 00000000..6afa04dc --- /dev/null +++ b/src/anemoi/training/config/diagnostics/callbacks/rollout_eval.yaml @@ -0,0 +1,4 @@ +# Add callbacks here +- _target_: anemoi.training.diagnostics.callbacks.evaluation.RolloutEval + rollout: ${dataloader.validation_rollout} + every_n_batches: 20 diff --git a/src/anemoi/training/config/diagnostics/eval_rollout.yaml b/src/anemoi/training/config/diagnostics/evaluation.yaml similarity index 53% rename from src/anemoi/training/config/diagnostics/eval_rollout.yaml rename to src/anemoi/training/config/diagnostics/evaluation.yaml index 50e9a647..d138d619 100644 --- a/src/anemoi/training/config/diagnostics/eval_rollout.yaml +++ b/src/anemoi/training/config/diagnostics/evaluation.yaml @@ -1,53 +1,8 @@ --- -eval: - enabled: False - # use this to evaluate the model over longer rollouts, every so many validation batches - rollout: 12 - frequency: 20 -plot: - enabled: True - asynchronous: True - frequency: 750 - sample_idx: 0 - per_sample: 6 - parameters: - - z_500 - - t_850 - - u_850 - - v_850 - - 2t - - 10u - - 10v - - sp - - tp - - cp - #Defining the accumulation levels for precipitation related fields and the colormap - accumulation_levels_plot: [0, 0.05, 0.1, 0.25, 0.5, 1, 1.5, 2, 3, 4, 5, 6, 7, 100] # in mm - cmap_accumulation: ["#ffffff", "#04e9e7", "#019ff4", "#0300f4", "#02fd02", "#01c501", "#008e00", "#fdf802", "#e5bc00", "#fd9500", "#fd0000", "#d40000", "#bc0000", "#f800fd"] - precip_and_related_fields: [tp, cp] - # Histogram and Spectrum plots - parameters_histogram: - - z_500 - - tp - - 2t - - 10u - - 10v - parameters_spectrum: - - z_500 - - tp - - 2t - - 10u - - 10v - #
group parameters by categories when visualizing contributions to the loss - # one-parameter groups are possible to highlight individual parameters - parameter_groups: - moisture: [tp, cp, tcw] - sfc_wind: [10u, 10v] - learned_features: False - longrollout: - enabled: False - rollout: [60] - frequency: 20 # every X epochs +defaults: + - plot: detailed + - callbacks: pretraining + - benchmark_profiler: detailed debug: # this will detect and trace back NaNs / Infs etc. but will slow down training @@ -57,6 +12,7 @@ debug: # remember to also activate the tensorboard logger (below) profiler: False +enable_checkpointing: True checkpoint: every_n_minutes: save_frequency: 30 # Approximate, as this is checked at the end of training steps @@ -94,6 +50,8 @@ log: terminal: True run_name: null # If set to null, the run name will be the a random UUID on_resume_create_child: True + expand_hyperparams: # Which keys in hyperparams to expand + - config interval: 100 enable_progress_bar: True diff --git a/src/anemoi/training/config/diagnostics/plot/detailed.yaml b/src/anemoi/training/config/diagnostics/plot/detailed.yaml new file mode 100644 index 00000000..b759c17b --- /dev/null +++ b/src/anemoi/training/config/diagnostics/plot/detailed.yaml @@ -0,0 +1,61 @@ +asynchronous: True # Whether to plot asynchronously +frequency: # Frequency of the plotting + batch: 750 + epoch: 5 + +# Parameters to plot +parameters: +- z_500 +- t_850 +- u_850 +- v_850 +- 2t +- 10u +- 10v +- sp +- tp +- cp + +# Sample index +sample_idx: 0 + +# Precipitation and related fields +precip_and_related_fields: [tp, cp] + +callbacks: + # Add plot callbacks here + - _target_: anemoi.training.diagnostics.callbacks.plot.GraphTrainableFeaturesPlot + every_n_epochs: 5 + - _target_: anemoi.training.diagnostics.callbacks.plot.PlotLoss + # group parameters by categories when visualizing contributions to the loss + # one-parameter groups are possible to highlight individual parameters + parameter_groups: + moisture: [tp, cp, tcw] + sfc_wind: [10u, 10v] + - _target_: anemoi.training.diagnostics.callbacks.plot.PlotSample + sample_idx: ${diagnostics.plot.sample_idx} + per_sample : 6 + parameters: ${diagnostics.plot.parameters} + #Defining the accumulation levels for precipitation related fields and the colormap + accumulation_levels_plot: [0, 0.05, 0.1, 0.25, 0.5, 1, 1.5, 2, 3, 4, 5, 6, 7, 100] # in mm + cmap_accumulation: ["#ffffff", "#04e9e7", "#019ff4", "#0300f4", "#02fd02", "#01c501", "#008e00", "#fdf802", "#e5bc00", "#fd9500", "#fd0000", "#d40000", "#bc0000", "#f800fd"] + precip_and_related_fields: ${diagnostics.plot.precip_and_related_fields} + + - _target_: anemoi.training.diagnostics.callbacks.plot.PlotSpectrum + # every_n_batches: 100 # Override for batch frequency + sample_idx: ${diagnostics.plot.sample_idx} + parameters: + - z_500 + - tp + - 2t + - 10u + - 10v + - _target_: anemoi.training.diagnostics.callbacks.plot.PlotHistogram + sample_idx: ${diagnostics.plot.sample_idx} + precip_and_related_fields: ${diagnostics.plot.precip_and_related_fields} + parameters: + - z_500 + - tp + - 2t + - 10u + - 10v diff --git a/src/anemoi/training/config/diagnostics/plot/none.yaml b/src/anemoi/training/config/diagnostics/plot/none.yaml new file mode 100644 index 00000000..3101f292 --- /dev/null +++ b/src/anemoi/training/config/diagnostics/plot/none.yaml @@ -0,0 +1 @@ +callbacks: [] diff --git a/src/anemoi/training/config/diagnostics/plot/rollout_eval.yaml b/src/anemoi/training/config/diagnostics/plot/rollout_eval.yaml new file mode 100644 index 
00000000..642e6e6b --- /dev/null +++ b/src/anemoi/training/config/diagnostics/plot/rollout_eval.yaml @@ -0,0 +1,67 @@ +asynchronous: True # Whether to plot asynchronously +frequency: # Frequency of the plotting + batch: 750 + epoch: 5 + +# Parameters to plot +parameters: +- z_500 +- t_850 +- u_850 +- v_850 +- 2t +- 10u +- 10v +- sp +- tp +- cp + +# Sample index +sample_idx: 0 + +# Precipitation and related fields +precip_and_related_fields: [tp, cp] + +callbacks: + # Add plot callbacks here + - _target_: anemoi.training.diagnostics.callbacks.plot.GraphTrainableFeaturesPlot + every_n_epochs: 5 + - _target_: anemoi.training.diagnostics.callbacks.plot.PlotLoss + # group parameters by categories when visualizing contributions to the loss + # one-parameter groups are possible to highlight individual parameters + parameter_groups: + moisture: [tp, cp, tcw] + sfc_wind: [10u, 10v] + - _target_: anemoi.training.diagnostics.callbacks.plot.PlotSample + sample_idx: ${diagnostics.plot.sample_idx} + per_sample : 6 + parameters: ${diagnostics.plot.parameters} + #Defining the accumulation levels for precipitation related fields and the colormap + accumulation_levels_plot: [0, 0.05, 0.1, 0.25, 0.5, 1, 1.5, 2, 3, 4, 5, 6, 7, 100] # in mm + cmap_accumulation: ["#ffffff", "#04e9e7", "#019ff4", "#0300f4", "#02fd02", "#01c501", "#008e00", "#fdf802", "#e5bc00", "#fd9500", "#fd0000", "#d40000", "#bc0000", "#f800fd"] + precip_and_related_fields: ${diagnostics.plot.precip_and_related_fields} + + - _target_: anemoi.training.diagnostics.callbacks.plot.PlotSpectrum + # every_n_batches: 100 # Override for batch frequency + sample_idx: ${diagnostics.plot.sample_idx} + parameters: + - z_500 + - tp + - 2t + - 10u + - 10v + - _target_: anemoi.training.diagnostics.callbacks.plot.PlotHistogram + sample_idx: ${diagnostics.plot.sample_idx} + precip_and_related_fields: ${diagnostics.plot.precip_and_related_fields} + parameters: + - z_500 + - tp + - 2t + - 10u + - 10v + - _target_: anemoi.training.diagnostics.callbacks.plot.LongRolloutPlots + rollout: + - ${dataloader.validation_rollout} + every_n_epochs: 20 + sample_idx: ${diagnostics.plot.sample_idx} + parameters: ${diagnostics.plot.parameters} diff --git a/src/anemoi/training/config/diagnostics/plot/simple.yaml b/src/anemoi/training/config/diagnostics/plot/simple.yaml new file mode 100644 index 00000000..2a987ccb --- /dev/null +++ b/src/anemoi/training/config/diagnostics/plot/simple.yaml @@ -0,0 +1,40 @@ +asynchronous: True # Whether to plot asynchronously +frequency: # Frequency of the plotting + batch: 750 + epoch: 10 + +# Parameters to plot +parameters: +- z_500 +- t_850 +- u_850 +- v_850 +- 2t +- 10u +- 10v +- sp +- tp +- cp + +# Sample index +sample_idx: 0 + +# Precipitation and related fields +precip_and_related_fields: [tp, cp] + +callbacks: + # Add plot callbacks here + - _target_: anemoi.training.diagnostics.callbacks.plot.PlotLoss + # group parameters by categories when visualizing contributions to the loss + # one-parameter groups are possible to highlight individual parameters + parameter_groups: + moisture: [tp, cp, tcw] + sfc_wind: [10u, 10v] + - _target_: anemoi.training.diagnostics.callbacks.plot.PlotSample + sample_idx: ${diagnostics.plot.sample_idx} + per_sample : 6 + parameters: ${diagnostics.plot.parameters} + #Defining the accumulation levels for precipitation related fields and the colormap + accumulation_levels_plot: [0, 0.05, 0.1, 0.25, 0.5, 1, 1.5, 2, 3, 4, 5, 6, 7, 100] # in mm + cmap_accumulation: ["#ffffff", "#04e9e7", "#019ff4", "#0300f4", 
"#02fd02", "#01c501", "#008e00", "#fdf802", "#e5bc00", "#fd9500", "#fd0000", "#d40000", "#bc0000", "#f800fd"] + precip_and_related_fields: ${diagnostics.plot.precip_and_related_fields} diff --git a/src/anemoi/training/config/graph/limited_area.yaml b/src/anemoi/training/config/graph/limited_area.yaml new file mode 100644 index 00000000..f17bc384 --- /dev/null +++ b/src/anemoi/training/config/graph/limited_area.yaml @@ -0,0 +1,60 @@ +--- +overwrite: True + +data: "data" +hidden: "hidden" + +nodes: + # Data nodes + data: + node_builder: + _target_: anemoi.graphs.nodes.ZarrDatasetNodes + dataset: ${dataloader.training.dataset} + attributes: ${graph.attributes.nodes} + # Hidden nodes + hidden: + node_builder: + _target_: anemoi.graphs.nodes.LimitedAreaTriNodes # options: ZarrDatasetNodes, NPZFileNodes, TriNodes + resolution: 5 # grid resolution for npz (o32, o48, ...) + reference_node_name: ${graph.data} + mask_attr_name: cutout + +edges: +# Encoder configuration +- source_name: ${graph.data} + target_name: ${graph.hidden} + edge_builder: + _target_: anemoi.graphs.edges.CutOffEdges # options: KNNEdges, CutOffEdges + cutoff_factor: 0.6 # only for cutoff method + attributes: ${graph.attributes.edges} +# Processor configuration +- source_name: ${graph.hidden} + target_name: ${graph.hidden} + edge_builder: + _target_: anemoi.graphs.edges.MultiScaleEdges + x_hops: 1 + attributes: ${graph.attributes.edges} +# Decoder configuration +- source_name: ${graph.hidden} + target_name: ${graph.data} + target_mask_attr_name: cutout + edge_builder: + _target_: anemoi.graphs.edges.KNNEdges # options: KNNEdges, CutOffEdges + num_nearest_neighbours: 3 # only for knn method + attributes: ${graph.attributes.edges} + + +attributes: + nodes: + area_weight: + _target_: anemoi.graphs.nodes.attributes.AreaWeights # options: Area, Uniform + norm: unit-max # options: l1, l2, unit-max, unit-sum, unit-std + cutout: + _target_: anemoi.graphs.nodes.attributes.CutOutMask + edges: + edge_length: + _target_: anemoi.graphs.edges.attributes.EdgeLength + norm: unit-std + edge_dirs: + _target_: anemoi.graphs.edges.attributes.EdgeDirection + norm: unit-std diff --git a/src/anemoi/training/config/graph/stretched_grid.yaml b/src/anemoi/training/config/graph/stretched_grid.yaml new file mode 100644 index 00000000..dad0172d --- /dev/null +++ b/src/anemoi/training/config/graph/stretched_grid.yaml @@ -0,0 +1,63 @@ +# Stretched grid graph config intended to be used with a cutout dataset. +# The stretched mesh resolution used here is intended for o96 global resolution with 10km +# limited area resolution. 
+overwrite: False + +data: "data" +hidden: "hidden" + +nodes: + data: + node_builder: + _target_: anemoi.graphs.nodes.ZarrDatasetNodes + dataset: ${dataloader.training.dataset} + attributes: + area_weight: + _target_: anemoi.graphs.nodes.attributes.AreaWeights + norm: unit-max + cutout: + _target_: anemoi.graphs.nodes.attributes.CutOutMask + hidden: + node_builder: + _target_: anemoi.graphs.nodes.StretchedTriNodes + lam_resolution: 8 + global_resolution: 5 + reference_node_name: ${graph.data} + mask_attr_name: cutout + margin_radius_km: 11 + attributes: + area_weights: + _target_: anemoi.graphs.nodes.attributes.AreaWeights + norm: unit-max + +edges: +# Encoder +- source_name: ${graph.data} + target_name: ${graph.hidden} + edge_builder: + _target_: anemoi.graphs.edges.KNNEdges + num_nearest_neighbours: 12 + attributes: ${graph.attributes.edges} +# Processor +- source_name: ${graph.hidden} + target_name: ${graph.hidden} + edge_builder: + _target_: anemoi.graphs.edges.MultiScaleEdges + x_hops: 1 + attributes: ${graph.attributes.edges} +# Decoder +- source_name: ${graph.hidden} + target_name: ${graph.data} + edge_builder: + _target_: anemoi.graphs.edges.KNNEdges + num_nearest_neighbours: 3 + attributes: ${graph.attributes.edges} + +attributes: + edges: + edge_length: + _target_: anemoi.graphs.edges.attributes.EdgeLength + norm: unit-max + edge_dirs: + _target_: anemoi.graphs.edges.attributes.EdgeDirection + norm: unit-std diff --git a/src/anemoi/training/config/model/gnn.yaml b/src/anemoi/training/config/model/gnn.yaml index a01bf860..4f4c176c 100644 --- a/src/anemoi/training/config/model/gnn.yaml +++ b/src/anemoi/training/config/model/gnn.yaml @@ -46,3 +46,26 @@ attributes: nodes: [] node_loss_weight: area_weight + +# Bounding configuration +bounding: #These are applied in order + + # Bound tp (total precipitation) with a Relu bounding layer + # ensuring a range of [0, infinity) to avoid negative precipitation values. + - _target_: anemoi.models.layers.bounding.ReluBounding #[0, infinity) + variables: + - tp + + # [OPTIONAL] Bound cp (convective precipitation) as a fraction of tp. + # This guarantees that cp is physically consistent with tp by restricting cp + # to a fraction of tp [0 to 1]. Uncomment the lines below to apply. + # NOTE: If this bounding strategy is used, the normalization of cp must be + # changed to "std" normalization, and the "cp" statistics should be remapped + # to those of tp to ensure consistency. + + # - _target_: anemoi.models.layers.bounding.FractionBounding # fraction of tp + # variables: + # - cp + # min_val: 0 + # max_val: 1 + # total_var: tp diff --git a/src/anemoi/training/config/model/graphtransformer.yaml b/src/anemoi/training/config/model/graphtransformer.yaml index 610de803..5c2e819a 100644 --- a/src/anemoi/training/config/model/graphtransformer.yaml +++ b/src/anemoi/training/config/model/graphtransformer.yaml @@ -51,3 +51,26 @@ attributes: nodes: [] node_loss_weight: area_weight + +# Bounding configuration +bounding: #These are applied in order + + # Bound tp (total precipitation) with a Relu bounding layer + # ensuring a range of [0, infinity) to avoid negative precipitation values. + - _target_: anemoi.models.layers.bounding.ReluBounding #[0, infinity) + variables: + - tp + + # [OPTIONAL] Bound cp (convective precipitation) as a fraction of tp. + # This guarantees that cp is physically consistent with tp by restricting cp + # to a fraction of tp [0 to 1]. Uncomment the lines below to apply. 
+ # NOTE: If this bounding strategy is used, the normalization of cp must be + # changed to "std" normalization, and the "cp" statistics should be remapped + # to those of tp to ensure consistency. + + # - _target_: anemoi.models.layers.bounding.FractionBounding # fraction of tp + # variables: + # - cp + # min_val: 0 + # max_val: 1 + # total_var: tp diff --git a/src/anemoi/training/config/model/transformer.yaml b/src/anemoi/training/config/model/transformer.yaml index 7c490990..b26c9ecc 100644 --- a/src/anemoi/training/config/model/transformer.yaml +++ b/src/anemoi/training/config/model/transformer.yaml @@ -50,3 +50,26 @@ attributes: nodes: [] node_loss_weight: area_weight + +# Bounding configuration +bounding: #These are applied in order + + # Bound tp (total precipitation) with a Relu bounding layer + # ensuring a range of [0, infinity) to avoid negative precipitation values. + - _target_: anemoi.models.layers.bounding.ReluBounding #[0, infinity) + variables: + - tp + + # [OPTIONAL] Bound cp (convective precipitation) as a fraction of tp. + # This guarantees that cp is physically consistent with tp by restricting cp + # to a fraction of tp [0 to 1]. Uncomment the lines below to apply. + # NOTE: If this bounding strategy is used, the normalization of cp must be + # changed to "std" normalization, and the "cp" statistics should be remapped + # to those of tp to ensure consistency. + + # - _target_: anemoi.models.layers.bounding.FractionBounding # fraction of tp + # variables: + # - cp + # min_val: 0 + # max_val: 1 + # total_var: tp diff --git a/src/anemoi/training/config/training/default.yaml b/src/anemoi/training/config/training/default.yaml index 870eeb7a..b471034e 100644 --- a/src/anemoi/training/config/training/default.yaml +++ b/src/anemoi/training/config/training/default.yaml @@ -19,6 +19,8 @@ multistep_input: 2 # the effective batch size becomes num-devices * batch_size * k accum_grad_batches: 1 +num_sanity_val_steps: 6 + # clipp gradients, 0 : don't clip, default algorithm: norm, alternative: value gradient_clip: val: 32. @@ -33,11 +35,36 @@ swa: # use ZeroRedundancyOptimizer ; saves memory for larger models zero_optimizer: False +# loss functions + # dynamic rescaling of the loss gradient # see https://arxiv.org/pdf/2306.06079.pdf, section 4.3.2 # don't enable this by default until it's been tested and proven beneficial + +# loss function for the model +training_loss: + # loss class to initialise + _target_: anemoi.training.losses.mse.WeightedMSELoss + # Scalars to include in loss calculation + # Available scalars include: + # - 'variable': See `variable_loss_scaling` for more information + scalars: ['variable'] + ignore_nans: False + loss_gradient_scaling: False +# Validation metrics calculation, +# This may be a list, in which case all metrics will be calculated +# and logged according to their name +validation_metrics: + # loss class to initialise + - _target_: anemoi.training.losses.mse.WeightedMSELoss + # Scalars to include in loss calculation + # Available scalars include, 'variable' + scalars: [] + # other kwargs + ignore_nans: True + # length of the "rollout" window (see Keisler's paper) rollout: start: 1 @@ -46,17 +73,22 @@ rollout: # maximum rollout to use max: 1 -max_epochs: 200 +# Set max_epochs or max_steps. Training stops at the first limit reached. 
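+# e.g. override on the command line (illustrative value): anemoi-training train training.max_steps=50000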
+max_epochs: null +max_steps: 150000 + lr: rate: 0.625e-4 #local_lr - iterations: 300000 + iterations: ${training.max_steps} # NOTE: When max_epochs < max_steps, scheduler will run for max_steps min: 3e-7 #Not scaled by #GPU # Changes in per-gpu batch_size should come with a rescaling of the local_lr # in order to keep a constant global_lr # global_lr = local_lr * num_gpus_per_node * num_nodes / gpus_per_model -loss_scaling: +# Variable loss scaling +# 'variable' must be included in `scalars` in the losses for this to be applied. +variable_loss_scaling: default: 1 pl: q: 0.6 #1 diff --git a/src/anemoi/training/data/__init__.py b/src/anemoi/training/data/__init__.py index 282d6a69..c167afa2 100644 --- a/src/anemoi/training/data/__init__.py +++ b/src/anemoi/training/data/__init__.py @@ -1,6 +1,8 @@ -# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# (C) Copyright 2024 Anemoi contributors. +# # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. diff --git a/src/anemoi/training/data/datamodule.py b/src/anemoi/training/data/datamodule.py index f64a3091..0d3d1b3f 100644 --- a/src/anemoi/training/data/datamodule.py +++ b/src/anemoi/training/data/datamodule.py @@ -16,6 +16,7 @@ import pytorch_lightning as pl from anemoi.datasets.data import open_dataset from anemoi.models.data_indices.collection import IndexCollection +from anemoi.utils.dates import frequency_to_seconds from omegaconf import DictConfig from omegaconf import OmegaConf from torch.utils.data import DataLoader @@ -42,23 +43,6 @@ def __init__(self, config: DictConfig) -> None: self.config = config - # Determine the step size relative to the data frequency - frequency = self.config.data.frequency - timestep = self.config.data.timestep - assert ( - isinstance(frequency, str) and isinstance(timestep, str) and frequency[-1] == "h" and timestep[-1] == "h" - ), f"Error in format of timestep, {timestep}, or data frequency, {frequency}" - assert ( - int(timestep[:-1]) % int(frequency[:-1]) == 0 - ), f"Timestep isn't a multiple of data frequency, {timestep}, or data frequency, {frequency}" - self.timeincrement = int(timestep[:-1]) // int(frequency[:-1]) - LOGGER.info( - "Timeincrement set to %s for data with frequency, %s, and timestep, %s", - self.timeincrement, - frequency, - timestep, - ) - self.global_rank = int(os.environ.get("SLURM_PROCID", "0")) # global rank self.model_comm_group_id = ( self.global_rank // self.config.hardware.num_gpus_per_model @@ -119,6 +103,34 @@ def metadata(self) -> dict: def data_indices(self) -> IndexCollection: return IndexCollection(self.config, self.ds_train.name_to_index) + @cached_property + def timeincrement(self) -> int: + """Determine the step size relative to the data frequency.""" + try: + frequency = frequency_to_seconds(self.config.data.frequency) + except ValueError as e: + msg = f"Error in data frequency, {self.config.data.frequency}" + raise ValueError(msg) from e + + try: + timestep = frequency_to_seconds(self.config.data.timestep) + except ValueError as e: + msg = f"Error in timestep, {self.config.data.timestep}" + raise ValueError(msg) from e + + assert timestep % frequency == 0, ( + f"Timestep ({self.config.data.timestep} == {timestep}) isn't a " + f"multiple of data frequency 
({self.config.data.frequency} == {frequency})." + ) + + LOGGER.info( + "Timeincrement set to %s for data with frequency, %s, and timestep, %s", + timestep // frequency, + frequency, + timestep, + ) + return timestep // frequency + @cached_property def ds_train(self) -> NativeGridDataset: return self._get_dataset( @@ -129,10 +141,8 @@ def ds_train(self) -> NativeGridDataset: @cached_property def ds_valid(self) -> NativeGridDataset: r = self.rollout - if self.config.diagnostics.eval.enabled: - r = max(r, self.config.diagnostics.eval.rollout) - if self.config.diagnostics.plot.get("longrollout") and self.config.diagnostics.plot.longrollout.enabled: - r = max(r, max(self.config.diagnostics.plot.longrollout.rollout)) + r = max(r, self.config.dataloader.get("validation_rollout", 1)) + assert self.config.dataloader.training.end < self.config.dataloader.validation.start, ( f"Training end date {self.config.dataloader.training.end} is not before" f"validation start date {self.config.dataloader.validation.start}" diff --git a/src/anemoi/training/diagnostics/__init__.py b/src/anemoi/training/diagnostics/__init__.py index 282d6a69..c167afa2 100644 --- a/src/anemoi/training/diagnostics/__init__.py +++ b/src/anemoi/training/diagnostics/__init__.py @@ -1,6 +1,8 @@ -# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# (C) Copyright 2024 Anemoi contributors. +# # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. diff --git a/src/anemoi/training/diagnostics/callbacks/__init__.py b/src/anemoi/training/diagnostics/callbacks/__init__.py index cf085eab..f3597843 100644 --- a/src/anemoi/training/diagnostics/callbacks/__init__.py +++ b/src/anemoi/training/diagnostics/callbacks/__init__.py @@ -1,1057 +1,66 @@ -# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# (C) Copyright 2024 Anemoi contributors. +# # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. -# * [WHY ARE CALLBACKS UNDER __init__.py?] 
-# * This functionality will be restructured in the near future -# * so for now callbacks are under __init__.py - from __future__ import annotations -import copy import logging -import sys -import time -import traceback -import uuid -from abc import ABC -from abc import abstractmethod -from concurrent.futures import ThreadPoolExecutor -from contextlib import nullcontext +from collections.abc import Iterable from datetime import timedelta -from functools import cached_property -from pathlib import Path from typing import TYPE_CHECKING from typing import Any from typing import Callable -import matplotlib.patches as mpatches -import matplotlib.pyplot as plt -import numpy as np -import torch -import torchinfo -from anemoi.utils.checkpoints import save_metadata -from pytorch_lightning.callbacks import Callback -from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint -from pytorch_lightning.utilities import rank_zero_only +from hydra.utils import instantiate +from omegaconf import DictConfig -from anemoi.training.diagnostics.plots import init_plot_settings -from anemoi.training.diagnostics.plots import plot_graph_features -from anemoi.training.diagnostics.plots import plot_histogram -from anemoi.training.diagnostics.plots import plot_loss -from anemoi.training.diagnostics.plots import plot_power_spectrum -from anemoi.training.diagnostics.plots import plot_predicted_multilevel_flat_sample +from anemoi.training.diagnostics.callbacks.checkpoint import AnemoiCheckpoint +from anemoi.training.diagnostics.callbacks.optimiser import LearningRateMonitor +from anemoi.training.diagnostics.callbacks.optimiser import StochasticWeightAveraging +from anemoi.training.diagnostics.callbacks.provenance import ParentUUIDCallback if TYPE_CHECKING: - import pytorch_lightning as pl - from omegaconf import DictConfig - from omegaconf import OmegaConf + from pytorch_lightning.callbacks import Callback LOGGER = logging.getLogger(__name__) -class ParallelExecutor(ThreadPoolExecutor): - """Wraps parallel execution and provides accurate information about errors. - - Extends ThreadPoolExecutor to preserve the original traceback and line number. - - Reference: https://stackoverflow.com/questions/19309514/getting-original-line- - number-for-exception-in-concurrent-futures/24457608#24457608 - """ - - def submit(self, fn: Any, *args, **kwargs) -> Callable: - """Submits the wrapped function instead of `fn`.""" - return super().submit(self._function_wrapper, fn, *args, **kwargs) - - def _function_wrapper(self, fn: Any, *args: list, **kwargs: dict) -> Callable: - """Wraps `fn` in order to preserve the traceback of any kind of.""" - try: - return fn(*args, **kwargs) - except Exception as exc: - raise sys.exc_info()[0](traceback.format_exc()) from exc - - -class BasePlotCallback(Callback, ABC): - """Factory for creating a callback that plots data to Experiment Logging.""" - - def __init__(self, config: OmegaConf) -> None: - """Initialise the BasePlotCallback abstract base class. 
- - Parameters - ---------- - config : OmegaConf - Config object - - """ - super().__init__() - self.config = config - self.save_basedir = config.hardware.paths.plots - self.plot_frequency = config.diagnostics.plot.frequency - self.post_processors = None - self.pre_processors = None - self.latlons = None - init_plot_settings() - - self.plot = self._plot - self._executor = None - - if self.config.diagnostics.plot.asynchronous: - self._executor = ParallelExecutor(max_workers=1) - self._error: BaseException | None = None - self.plot = self._async_plot - - @rank_zero_only - def _output_figure( - self, - logger: pl.loggers.base.LightningLoggerBase, - fig: plt.Figure, - epoch: int, - tag: str = "gnn", - exp_log_tag: str = "val_pred_sample", - ) -> None: - """Figure output: save to file and/or display in notebook.""" - if self.save_basedir is not None: - save_path = Path( - self.save_basedir, - "plots", - f"{tag}_epoch{epoch:03d}.png", - ) - - save_path.parent.mkdir(parents=True, exist_ok=True) - fig.savefig(save_path, dpi=100, bbox_inches="tight") - if self.config.diagnostics.log.wandb.enabled: - import wandb - - logger.experiment.log({exp_log_tag: wandb.Image(fig)}) - - if self.config.diagnostics.log.mlflow.enabled: - run_id = logger.run_id - logger.experiment.log_artifact(run_id, str(save_path)) - - plt.close(fig) # cleanup - - def teardown(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str) -> None: - """Method is called to close the threads.""" - del trainer, pl_module, stage # unused - if self._executor is not None: - self._executor.shutdown(wait=True) - - def apply_output_mask(self, pl_module: pl.LightningModule, data: torch.Tensor) -> torch.Tensor: - if hasattr(pl_module, "output_mask") and pl_module.output_mask is not None: - # Fill with NaNs values where the mask is False - data[:, :, ~pl_module.output_mask, :] = np.nan - - return data - - @abstractmethod - @rank_zero_only - def _plot( - *args: list, - **kwargs: dict, - ) -> None: ... - - @rank_zero_only - def _async_plot( - self, - trainer: pl.Trainer, - *args: list, - **kwargs: dict, - ) -> None: - """To execute the plot function but ensuring we catch any errors.""" - future = self._executor.submit( - self._plot, - trainer, - *args, - **kwargs, - ) - # otherwise the error won't be thrown till the validation epoch is finished - try: - future.result() - except Exception: - LOGGER.exception("Critical error occurred in asynchronous plots.") - sys.exit(1) - - -class RolloutEval(Callback): - """Evaluates the model performance over a (longer) rollout window.""" - - def __init__(self, config: OmegaConf) -> None: - """Initialize RolloutEval callback. 
- - Parameters - ---------- - config : dict - Dictionary with configuration settings - - """ - super().__init__() - - LOGGER.debug( - "Setting up RolloutEval callback with rollout = %d, frequency = %d ...", - config.diagnostics.eval.rollout, - config.diagnostics.eval.frequency, - ) - self.rollout = config.diagnostics.eval.rollout - self.frequency = config.diagnostics.eval.frequency - - def _eval( - self, - pl_module: pl.LightningModule, - batch: torch.Tensor, - ) -> None: - loss = torch.zeros(1, dtype=batch.dtype, device=pl_module.device, requires_grad=False) - metrics = {} - - # start rollout - batch = pl_module.model.pre_processors(batch, in_place=False) - x = batch[ - :, - 0 : pl_module.multi_step, - ..., - pl_module.data_indices.internal_data.input.full, - ] # (bs, multi_step, latlon, nvar) - assert ( - batch.shape[1] >= self.rollout + pl_module.multi_step - ), "Batch length not sufficient for requested rollout length!" - - with torch.no_grad(): - for rollout_step in range(self.rollout): - y_pred = pl_module(x) # prediction at rollout step rollout_step, shape = (bs, latlon, nvar) - y = batch[ - :, - pl_module.multi_step + rollout_step, - ..., - pl_module.data_indices.internal_data.output.full, - ] # target, shape = (bs, latlon, nvar) - # y includes the auxiliary variables, so we must leave those out when computing the loss - loss += pl_module.loss(y_pred, y) - - x = pl_module.advance_input(x, y_pred, batch, rollout_step) - - metrics_next, _ = pl_module.calculate_val_metrics(y_pred, y, rollout_step) - metrics.update(metrics_next) - - # scale loss - loss *= 1.0 / self.rollout - self._log(pl_module, loss, metrics, batch.shape[0]) - - def _log(self, pl_module: pl.LightningModule, loss: torch.Tensor, metrics: dict, bs: int) -> None: - pl_module.log( - f"val_r{self.rollout}_wmse", - loss, - on_epoch=True, - on_step=True, - prog_bar=False, - logger=pl_module.logger_enabled, - batch_size=bs, - sync_dist=False, - rank_zero_only=True, - ) - for mname, mvalue in metrics.items(): - pl_module.log( - f"val_r{self.rollout}_" + mname, - mvalue, - on_epoch=True, - on_step=False, - prog_bar=False, - logger=pl_module.logger_enabled, - batch_size=bs, - sync_dist=False, - rank_zero_only=True, - ) - - @rank_zero_only - def on_validation_batch_end( - self, - trainer: pl.Trainer, - pl_module: pl.LightningModule, - outputs: list, - batch: torch.Tensor, - batch_idx: int, - ) -> None: - del outputs # outputs are not used - if batch_idx % self.frequency == 0: - precision_mapping = { - "16-mixed": torch.float16, - "bf16-mixed": torch.bfloat16, - } - prec = trainer.precision - dtype = precision_mapping.get(prec) - context = torch.autocast(device_type=batch.device.type, dtype=dtype) if dtype is not None else nullcontext() - - with context: - self._eval(pl_module, batch) - - -class LongRolloutPlots(BasePlotCallback): - """Evaluates the model performance over a (longer) rollout window.""" - - def __init__(self, config) -> None: - """Initialize RolloutEval callback. 
- - Parameters - ---------- - config : dict - Dictionary with configuration settings - """ - super().__init__(config) - - LOGGER.debug( - "Setting up callback for plots with long rollout: rollout = %d, frequency = every %d epoch ...", - config.diagnostics.plot.longrollout.rollout, - config.diagnostics.plot.longrollout.frequency, - ) - self.rollout = config.diagnostics.plot.longrollout.rollout - self.eval_frequency = config.diagnostics.plot.longrollout.frequency - self.sample_idx = self.config.diagnostics.plot.sample_idx - - @rank_zero_only - def _plot( - self, - trainer, - pl_module: pl.LightningModule, - batch: torch.Tensor, - batch_idx, - epoch, - ) -> None: - - start_time = time.time() - - logger = trainer.logger - - # Build dictionary of inidicies and parameters to be plotted - plot_parameters_dict = { - pl_module.data_indices.model.output.name_to_index[name]: ( - name, - name not in self.config.data.get("diagnostic", []), - ) - for name in self.config.diagnostics.plot.parameters - } - - if self.post_processors is None: - # Copy to be used across all the training cycle - self.post_processors = copy.deepcopy(pl_module.model.post_processors).cpu() - if self.latlons is None: - self.latlons = np.rad2deg(pl_module.latlons_data.clone().cpu().numpy()) - local_rank = pl_module.local_rank - - batch = pl_module.model.pre_processors(batch, in_place=False) - # prepare input tensor for rollout from preprocessed batch - x = batch[ - :, - 0 : pl_module.multi_step, - ..., - pl_module.data_indices.internal_data.input.full, - ] # (bs, multi_step, latlon, nvar) - assert ( - batch.shape[1] >= max(self.rollout) + pl_module.multi_step - ), "Batch length not sufficient for requested rollout length!" - - # prepare input tensor for plotting - input_tensor_0 = batch[ - self.sample_idx, - pl_module.multi_step - 1, - ..., - pl_module.data_indices.internal_data.output.full, - ].cpu() - data_0 = self.post_processors(input_tensor_0).numpy() - - # start rollout - with torch.no_grad(): - for rollout_step in range(max(self.rollout)): - y_pred = pl_module(x) # prediction at rollout step rollout_step, shape = (bs, latlon, nvar) - - x = pl_module.advance_input(x, y_pred, batch, rollout_step) - - if (rollout_step + 1) in self.rollout: - # prepare true output tensor for plotting - input_tensor_rollout_step = batch[ - self.sample_idx, - pl_module.multi_step + rollout_step, # (pl_module.multi_step - 1) + (rollout_step + 1) - ..., - pl_module.data_indices.internal_data.output.full, - ].cpu() - data_rollout_step = self.post_processors(input_tensor_rollout_step).numpy() - - # prepare predicted output tensor for plotting - output_tensor = self.post_processors( - y_pred[self.sample_idx : self.sample_idx + 1, ...].cpu() - ).numpy() - - fig = plot_predicted_multilevel_flat_sample( - plot_parameters_dict, - self.config.diagnostics.plot.per_sample, - self.latlons, - self.config.diagnostics.plot.get("accumulation_levels_plot", None), - self.config.diagnostics.plot.get("cmap_accumulation", None), - data_0.squeeze(), - data_rollout_step.squeeze(), - output_tensor[0, 0, :, :], # rolloutstep, first member - # force_global_view=self.show_entire_globe, - ) - - self._output_figure( - logger, - fig, - epoch=epoch, - tag=f"gnn_pred_val_sample_rstep{rollout_step:03d}_batch{batch_idx:04d}_rank0", - exp_log_tag=f"val_pred_sample_rstep{rollout_step:03d}_rank{local_rank:01d}", - ) - LOGGER.info(f"Time taken to plot samples after longer rollout: {int(time.time() - start_time)} seconds") - - @rank_zero_only - def on_validation_batch_end(self, trainer, 
pl_module, output, batch, batch_idx) -> None: - if (batch_idx) % self.plot_frequency == 0 and (trainer.current_epoch + 1) % self.eval_frequency == 0: - precision_mapping = { - "16-mixed": torch.float16, - "bf16-mixed": torch.bfloat16, - } - prec = trainer.precision - dtype = precision_mapping.get(prec) - context = torch.autocast(device_type=batch.device.type, dtype=dtype) if dtype is not None else nullcontext() - - with context: - self._plot(trainer, pl_module, batch, batch_idx, epoch=trainer.current_epoch) - - -class GraphTrainableFeaturesPlot(BasePlotCallback): - """Visualize the trainable features defined at the data and hidden graph nodes. - - TODO: How best to visualize the learned edge embeddings? Offline, perhaps - using code from @Simon's notebook? - """ - - def __init__(self, config: OmegaConf) -> None: - """Initialise the GraphTrainableFeaturesPlot callback. - - Parameters - ---------- - config : OmegaConf - Config object - - """ - super().__init__(config) - self._graph_name_data = config.graph.data - self._graph_name_hidden = config.graph.hidden - - @rank_zero_only - def _plot( - self, - trainer: pl.Trainer, - latlons: np.ndarray, - features: np.ndarray, - epoch: int, - tag: str, - exp_log_tag: str, - ) -> None: - fig = plot_graph_features(latlons, features) - self._output_figure(trainer.logger, fig, epoch=epoch, tag=tag, exp_log_tag=exp_log_tag) - - @rank_zero_only - def on_validation_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: - - model = pl_module.model.module.model if hasattr(pl_module.model, "module") else pl_module.model.model - graph = pl_module.graph_data.cpu().detach() - epoch = trainer.current_epoch - - if model.trainable_data is not None: - data_coords = np.rad2deg(graph[(self._graph_name_data, "to", self._graph_name_data)].ecoords_rad.numpy()) - - self.plot( - trainer, - data_coords, - model.trainable_data.trainable.cpu().detach().numpy(), - epoch=epoch, - tag="trainable_data", - exp_log_tag="trainable_data", - ) - - if model.trainable_hidden is not None: - hidden_coords = np.rad2deg( - graph[(self._graph_name_hidden, "to", self._graph_name_hidden)].hcoords_rad.numpy(), - ) - - self.plot( - trainer, - hidden_coords, - model.trainable_hidden.trainable.cpu().detach().numpy(), - epoch=epoch, - tag="trainable_hidden", - exp_log_tag="trainable_hidden", - ) - - -class PlotLoss(BasePlotCallback): - """Plots the unsqueezed loss over rollouts.""" - - def __init__(self, config: OmegaConf) -> None: - """Initialise the PlotLoss callback. 
- - Parameters - ---------- - config : OmegaConf - Object with configuration settings - - """ - super().__init__(config) - self.parameter_names = None - self.parameter_groups = self.config.diagnostics.plot.parameter_groups - if self.parameter_groups is None: - self.parameter_groups = {} - - @cached_property - def sort_and_color_by_parameter_group(self) -> tuple[np.ndarray, np.ndarray, dict, list]: - """Sort parameters by group and prepare colors.""" - - def automatically_determine_group(name: str) -> str: - # first prefix of parameter name is group name - parts = name.split("_") - return parts[0] - - # group parameters by their determined group name for > 15 parameters - if len(self.parameter_names) <= 15: - # for <= 15 parameters, keep the full name of parameters - parameters_to_groups = np.array(self.parameter_names) - sort_by_parameter_group = np.arange(len(self.parameter_names), dtype=int) - else: - parameters_to_groups = np.array( - [ - next( - ( - group_name - for group_name, group_parameters in self.parameter_groups.items() - if name in group_parameters - ), - automatically_determine_group(name), - ) - for name in self.parameter_names - ], - ) - - unique_group_list, group_inverse, group_counts = np.unique( - parameters_to_groups, - return_inverse=True, - return_counts=True, - ) - - # join parameter groups that appear only once and are not given in config-file - unique_group_list = np.array( - [ - unique_group_list[tn] if count > 1 or unique_group_list[tn] in self.parameter_groups else "other" - for tn, count in enumerate(group_counts) - ], - ) - parameters_to_groups = unique_group_list[group_inverse] - unique_group_list, group_inverse = np.unique(parameters_to_groups, return_inverse=True) - - # sort parameters by groups - sort_by_parameter_group = np.argsort(group_inverse, kind="stable") - - # apply new order to parameters - sorted_parameter_names = np.array(self.parameter_names)[sort_by_parameter_group] - parameters_to_groups = parameters_to_groups[sort_by_parameter_group] - unique_group_list, group_inverse, group_counts = np.unique( - parameters_to_groups, - return_inverse=True, - return_counts=True, - ) - - # get a color per group and project to parameter list - cmap = "tab10" if len(unique_group_list) <= 10 else "tab20" - if len(unique_group_list) > 20: - LOGGER.warning("More than 20 groups detected, but colormap has only 20 colors.") - # if all groups have count 1 use black color - bar_color_per_group = ( - np.tile("k", len(group_counts)) - if not np.any(group_counts - 1) - else plt.get_cmap(cmap)(np.linspace(0, 1, len(unique_group_list))) - ) - - # set x-ticks - x_tick_positions = np.cumsum(group_counts) - group_counts / 2 - 0.5 - xticks = dict(zip(unique_group_list, x_tick_positions)) - - legend_patches = [] - for group_idx, group in enumerate(unique_group_list): - text_label = f"{group}: " - string_length = len(text_label) - for ii in np.where(group_inverse == group_idx)[0]: - text_label += sorted_parameter_names[ii] + ", " - string_length += len(sorted_parameter_names[ii]) + 2 - if string_length > 50: - # linebreak after 50 characters - text_label += "\n" - string_length = 0 - legend_patches.append(mpatches.Patch(color=bar_color_per_group[group_idx], label=text_label[:-2])) - - return sort_by_parameter_group, bar_color_per_group[group_inverse], xticks, legend_patches - - @rank_zero_only - def _plot( - self, - trainer: pl.Trainer, - pl_module: pl.Lightning_module, - outputs: list[torch.Tensor], - batch: torch.Tensor, - batch_idx: int, - epoch: int, - ) -> None: - logger = 
trainer.logger - - parameter_names = list(pl_module.data_indices.internal_model.output.name_to_index.keys()) - parameter_positions = list(pl_module.data_indices.internal_model.output.name_to_index.values()) - # reorder parameter_names by position - self.parameter_names = [parameter_names[i] for i in np.argsort(parameter_positions)] - - batch = pl_module.model.pre_processors(batch, in_place=False) - for rollout_step in range(pl_module.rollout): - y_hat = outputs[1][rollout_step] - y_true = batch[ - :, pl_module.multi_step + rollout_step, ..., pl_module.data_indices.internal_data.output.full - ] - loss = pl_module.loss(y_hat, y_true, squash=False).cpu().numpy() - - sort_by_parameter_group, colors, xticks, legend_patches = self.sort_and_color_by_parameter_group - fig = plot_loss(loss[sort_by_parameter_group], colors, xticks, legend_patches) - - self._output_figure( - logger, - fig, - epoch=epoch, - tag=f"loss_rstep_rstep{rollout_step:02d}_rank{pl_module.local_rank:01d}", - exp_log_tag=f"loss_sample_rstep{rollout_step:02d}_rank{pl_module.local_rank:01d}", - ) - - def on_validation_batch_end( - self, - trainer: pl.Trainer, - pl_module: pl.LightningModule, - outputs: list[torch.Tensor], - batch: torch.Tensor, - batch_idx: int, - ) -> None: - if batch_idx % self.plot_frequency == 0: - self.plot(trainer, pl_module, outputs, batch, batch_idx, epoch=trainer.current_epoch) - - -class PlotSample(BasePlotCallback): - """Plots a post-processed sample: input, target and prediction.""" - - def __init__(self, config: OmegaConf) -> None: - """Initialise the PlotSample callback. - - Parameters - ---------- - config : OmegaConf - Config object - - """ - super().__init__(config) - self.sample_idx = self.config.diagnostics.plot.sample_idx - self.precip_and_related_fields = self.config.diagnostics.plot.precip_and_related_fields - LOGGER.info(f"Using defined accumulation colormap for fields: {self.precip_and_related_fields}") - - @rank_zero_only - def _plot( - self, - trainer: pl.Trainer, - pl_module: pl.Lightning_module, - outputs: list[torch.Tensor], - batch: torch.Tensor, - batch_idx: int, - epoch: int, - ) -> None: - logger = trainer.logger - - # Build dictionary of indices and parameters to be plotted - diagnostics = [] if self.config.data.diagnostic is None else self.config.data.diagnostic - plot_parameters_dict = { - pl_module.data_indices.model.output.name_to_index[name]: (name, name not in diagnostics) - for name in self.config.diagnostics.plot.parameters - } - - # When running in Async mode, it might happen that in the last epoch these tensors - # have been moved to the cpu (and then the denormalising would fail as the 'input_tensor' would be on CUDA - # but internal ones would be on the cpu), The lines below allow to address this problem - if self.post_processors is None: - # Copy to be used across all the training cycle - self.post_processors = copy.deepcopy(pl_module.model.post_processors).cpu() - if self.latlons is None: - self.latlons = np.rad2deg(pl_module.latlons_data.clone().cpu().numpy()) - local_rank = pl_module.local_rank - - batch = pl_module.model.pre_processors(batch, in_place=False) - input_tensor = batch[ - self.sample_idx, - pl_module.multi_step - 1 : pl_module.multi_step + pl_module.rollout + 1, - ..., - pl_module.data_indices.internal_data.output.full, - ].cpu() - data = self.post_processors(input_tensor) - - output_tensor = self.post_processors( - torch.cat(tuple(x[self.sample_idx : self.sample_idx + 1, ...].cpu() for x in outputs[1])), - in_place=False, - ) - - output_tensor = 
pl_module.output_mask.apply(output_tensor, dim=2, fill_value=np.nan).numpy() - data[1:, ...] = pl_module.output_mask.apply(data[1:, ...], dim=2, fill_value=np.nan) - data = data.numpy() - - for rollout_step in range(pl_module.rollout): - fig = plot_predicted_multilevel_flat_sample( - plot_parameters_dict, - self.config.diagnostics.plot.per_sample, - self.latlons, - self.config.diagnostics.plot.accumulation_levels_plot, - self.config.diagnostics.plot.cmap_accumulation, - data[0, ...].squeeze(), - data[rollout_step + 1, ...].squeeze(), - output_tensor[rollout_step, ...], - precip_and_related_fields=self.precip_and_related_fields, - ) - - self._output_figure( - logger, - fig, - epoch=epoch, - tag=f"gnn_pred_val_sample_rstep{rollout_step:02d}_batch{batch_idx:04d}_rank0", - exp_log_tag=f"val_pred_sample_rstep{rollout_step:02d}_rank{local_rank:01d}", - ) - - def on_validation_batch_end( - self, - trainer: pl.Trainer, - pl_module: pl.Lightning_module, - outputs: list[torch.Tensor], - batch: torch.Tensor, - batch_idx: int, - ) -> None: - if batch_idx % self.plot_frequency == 0: - self.plot(trainer, pl_module, outputs, batch, batch_idx, epoch=trainer.current_epoch) - +def nestedget(conf: DictConfig, key: str, default: Any) -> Any: + """Get a nested key from a DictConfig object. -class PlotAdditionalMetrics(BasePlotCallback): - """Plots TP related metric comparing target and prediction. - - The actual increment (output - input) is plot for prognostic variables while the output is plot for diagnostic ones. - - - Power Spectrum - - Histograms + E.g. + >>> nestedget(config, "diagnostics.log.wandb.enabled", False) """ + keys = key.split(".") + for k in keys: + conf = conf.get(k, default) + if not isinstance(conf, (dict, DictConfig)): + break + return conf + + +# Callbacks to add according to flags in the config +# Can be function to check status from config +CONFIG_ENABLED_CALLBACKS: list[tuple[list[str] | str | Callable[[DictConfig], bool], type[Callback]]] = [ + ("training.swa.enabled", StochasticWeightAveraging), + ( + lambda config: nestedget(config, "diagnostics.log.wandb.enabled", False) + or nestedget(config, "diagnostics.log.mlflow.enabled", False), + LearningRateMonitor, + ), +] + + +def _get_checkpoint_callback(config: DictConfig) -> list[AnemoiCheckpoint]: + """Get checkpointing callbacks.""" + if not config.diagnostics.get("enable_checkpointing", True): + return [] - def __init__(self, config: OmegaConf) -> None: - """Initialise the PlotAdditionalMetrics callback. 
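The new `nestedget` helper above replaces ad-hoc attribute chains with a tolerant dotted-key lookup. A minimal sketch of its behaviour, assuming `nestedget` is importable from the `anemoi.training.diagnostics.callbacks` package and using a hypothetical config:

```python
# Minimal sketch of nestedget's dotted-key lookup; the config content is
# hypothetical and the import path is assumed from the file layout above.
from omegaconf import OmegaConf

from anemoi.training.diagnostics.callbacks import nestedget

config = OmegaConf.create({"diagnostics": {"log": {"wandb": {"enabled": True}}}})

# Present key: walks diagnostics -> log -> wandb -> enabled.
assert nestedget(config, "diagnostics.log.wandb.enabled", False) is True
# Missing key: returns the default as soon as a segment is absent or the
# intermediate value is no longer a mapping.
assert nestedget(config, "diagnostics.log.mlflow.enabled", False) is False
```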
- - Parameters - ---------- - config : OmegaConf - Config object - - """ - super().__init__(config) - self.sample_idx = self.config.diagnostics.plot.sample_idx - self.precip_and_related_fields = self.config.diagnostics.plot.precip_and_related_fields - LOGGER.info(f"Using precip histogram plotting method for fields: {self.precip_and_related_fields}") - - @rank_zero_only - def _plot( - self, - trainer: pl.Trainer, - pl_module: pl.LightningModule, - outputs: list, - batch: torch.Tensor, - batch_idx: int, - epoch: int, - ) -> None: - logger = trainer.logger - - # When running in Async mode, it might happen that in the last epoch these tensors - # have been moved to the cpu (and then the denormalising would fail as the 'input_tensor' would be on CUDA - # but internal ones would be on the cpu), The lines below allow to address this problem - if self.pre_processors is None: - # Copy to be used across all the training cycle - self.pre_processors = copy.deepcopy(pl_module.model.pre_processors).cpu() - if self.post_processors is None: - # Copy to be used across all the training cycle - self.post_processors = copy.deepcopy(pl_module.model.post_processors).cpu() - if self.latlons is None: - self.latlons = np.rad2deg(pl_module.latlons_data.clone().cpu().numpy()) - local_rank = pl_module.local_rank - batch = pl_module.model.pre_processors(batch, in_place=False) - input_tensor = batch[ - self.sample_idx, - pl_module.multi_step - 1 : pl_module.multi_step + pl_module.rollout + 1, - ..., - pl_module.data_indices.internal_data.output.full, - ].cpu() - data = self.post_processors(input_tensor) - output_tensor = self.post_processors( - torch.cat(tuple(x[self.sample_idx : self.sample_idx + 1, ...].cpu() for x in outputs[1])), - in_place=False, - ) - - output_tensor = pl_module.output_mask.apply(output_tensor, dim=2, fill_value=np.nan).numpy() - data[1:, ...] 
= pl_module.output_mask.apply(data[1:, ...], dim=2, fill_value=np.nan) - data = data.numpy() - - for rollout_step in range(pl_module.rollout): - if self.config.diagnostics.plot.parameters_histogram is not None: - # Build dictionary of inidicies and parameters to be plotted - - diagnostics = [] if self.config.data.diagnostic is None else self.config.data.diagnostic - plot_parameters_dict_histogram = { - pl_module.data_indices.model.output.name_to_index[name]: (name, name not in diagnostics) - for name in self.config.diagnostics.plot.parameters_histogram - } - - fig = plot_histogram( - plot_parameters_dict_histogram, - data[0, ...].squeeze(), - data[rollout_step + 1, ...].squeeze(), - output_tensor[rollout_step, ...], - precip_and_related_fields=self.precip_and_related_fields, - ) - - self._output_figure( - logger, - fig, - epoch=epoch, - tag=f"gnn_pred_val_histo_rstep_{rollout_step:02d}_batch{batch_idx:04d}_rank0", - exp_log_tag=f"val_pred_histo_rstep_{rollout_step:02d}_rank{local_rank:01d}", - ) - - if self.config.diagnostics.plot.parameters_spectrum is not None: - # Build dictionary of inidicies and parameters to be plotted - diagnostics = [] if self.config.data.diagnostic is None else self.config.data.diagnostic - - plot_parameters_dict_spectrum = { - pl_module.data_indices.model.output.name_to_index[name]: (name, name not in diagnostics) - for name in self.config.diagnostics.plot.parameters_spectrum - } - - fig = plot_power_spectrum( - plot_parameters_dict_spectrum, - self.latlons, - data[0, ...].squeeze(), - data[rollout_step + 1, ...].squeeze(), - output_tensor[rollout_step, ...], - ) - - self._output_figure( - logger, - fig, - epoch=epoch, - tag=f"gnn_pred_val_spec_rstep_{rollout_step:02d}_batch{batch_idx:04d}_rank0", - exp_log_tag=f"val_pred_spec_rstep_{rollout_step:02d}_rank{local_rank:01d}", - ) - - def on_validation_batch_end( - self, - trainer: pl.Trainer, - pl_module: pl.LightningModule, - outputs: list[torch.Tensor], - batch: torch.Tensor, - batch_idx: int, - ) -> None: - if batch_idx % self.plot_frequency == 0: - self.plot(trainer, pl_module, outputs, batch, batch_idx, epoch=trainer.current_epoch) - - -class ParentUUIDCallback(Callback): - """A callback that retrieves the parent UUID for a model, if it is a child model.""" - - def __init__(self, config: OmegaConf) -> None: - """Initialise the ParentUUIDCallback callback. - - Parameters - ---------- - config : OmegaConf - Config object - - """ - super().__init__() - self.config = config - - def on_load_checkpoint( - self, - trainer: pl.Trainer, - pl_module: pl.LightningModule, - checkpoint: torch.nn.Module, - ) -> None: - del trainer # unused - pl_module.hparams["metadata"]["parent_uuid"] = checkpoint["hyper_parameters"]["metadata"]["uuid"] - - -class AnemoiCheckpoint(ModelCheckpoint): - """A checkpoint callback that saves the model after every validation epoch.""" - - def __init__(self, config: OmegaConf, **kwargs: dict) -> None: - """Initialise the AnemoiCheckpoint callback. - - Parameters - ---------- - config : OmegaConf - Config object - kwargs : dict - Additional keyword arguments for Pytorch ModelCheckpoint - - """ - super().__init__(**kwargs) - self.config = config - self.start = time.time() - self._model_metadata = None - self._tracker_metadata = None - self._tracker_name = None - - @staticmethod - def _torch_drop_down(trainer: pl.Trainer) -> torch.nn.Module: - # Get the model from the DataParallel wrapper, for single and multi-gpu cases - assert hasattr(trainer, "model"), "Trainer has no attribute 'model'! 
Is the Pytorch Lightning version correct?" - return trainer.model.module.model if hasattr(trainer.model, "module") else trainer.model.model - - @rank_zero_only - def model_metadata(self, model: torch.nn.Module) -> dict: - if self._model_metadata is not None: - return self._model_metadata - - self._model_metadata = { - "model": model.__class__.__name__, - "trainable_parameters": sum(p.numel() for p in model.parameters() if p.requires_grad), - "total_parameters": sum(p.numel() for p in model.parameters()), - "summary": repr( - torchinfo.summary( - model, - depth=50, - verbose=0, - row_settings=["var_names"], - ), - ), - } - - return self._model_metadata - - def tracker_metadata(self, trainer: pl.Trainer) -> dict: - if self._tracker_metadata is not None: - return {self._tracker_name: self._tracker_metadata} - - if self.config.diagnostics.log.wandb.enabled: - self._tracker_name = "wand" - import wandb - - run = wandb.run - if run is not None: - self._tracker_metadata = { - "id": run.id, - "name": run.name, - "url": run.url, - "project": run.project, - } - return {self._tracker_name: self._tracker_metadata} - - if self.config.diagnostics.log.mlflow.enabled: - self._tracker_name = "mlflow" - - from anemoi.training.diagnostics.mlflow.logger import AnemoiMLflowLogger - - mlflow_logger = next(logger for logger in trainer.loggers if isinstance(logger, AnemoiMLflowLogger)) - run_id = mlflow_logger.run_id - run = mlflow_logger._mlflow_client.get_run(run_id) - - if run is not None: - self._tracker_metadata = { - "id": run.info.run_id, - "name": run.info.run_name, - "url": run.info.artifact_uri, - "project": run.info.experiment_id, - } - return {self._tracker_name: self._tracker_metadata} - - return {} - - def _remove_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None: - """Calls the strategy to remove the checkpoint file.""" - super()._remove_checkpoint(trainer, filepath) - trainer.strategy.remove_checkpoint(self._get_inference_checkpoint_filepath(filepath)) - - def _get_inference_checkpoint_filepath(self, filepath: str) -> str: - """Defines the filepath for the inference checkpoint.""" - return Path(filepath).parent / Path("inference-" + str(Path(filepath).name)) - - def _save_checkpoint(self, trainer: pl.Trainer, lightning_checkpoint_filepath: str) -> None: - if trainer.is_global_zero: - model = self._torch_drop_down(trainer) - - # We want a different uuid each time we save the model - # so we can tell them apart in the catalogue (i.e. 
different epochs) - checkpoint_uuid = str(uuid.uuid4()) - trainer.lightning_module._hparams["metadata"]["uuid"] = checkpoint_uuid - - trainer.lightning_module._hparams["metadata"]["model"] = self.model_metadata(model) - trainer.lightning_module._hparams["metadata"]["tracker"] = self.tracker_metadata(trainer) - - trainer.lightning_module._hparams["metadata"]["training"] = { - "current_epoch": trainer.current_epoch, - "global_step": trainer.global_step, - "elapsed_time": time.time() - self.start, - } - - Path(lightning_checkpoint_filepath).parent.mkdir(parents=True, exist_ok=True) - - save_config = model.config - model.config = None - - tmp_metadata = model.metadata - model.metadata = None - - metadata = dict(**tmp_metadata) - - inference_checkpoint_filepath = self._get_inference_checkpoint_filepath(lightning_checkpoint_filepath) - - torch.save(model, inference_checkpoint_filepath) - - save_metadata(inference_checkpoint_filepath, metadata) - - model.config = save_config - model.metadata = tmp_metadata - - self._last_global_step_saved = trainer.global_step - - trainer.strategy.barrier() - - # saving checkpoint used for pytorch-lightning based training - trainer.save_checkpoint(lightning_checkpoint_filepath, self.save_weights_only) - - self._last_global_step_saved = trainer.global_step - self._last_checkpoint_saved = lightning_checkpoint_filepath - - if trainer.is_global_zero: - from weakref import proxy - - # save metadata for the training checkpoint in the same format as inference - save_metadata(lightning_checkpoint_filepath, metadata) - - # notify loggers - for logger in trainer.loggers: - logger.after_save_checkpoint(proxy(self)) - - -def get_callbacks(config: DictConfig) -> list: # noqa: C901 - """Setup callbacks for PyTorch Lightning trainer. - - Parameters - ---------- - config : DictConfig - Job configuration - - Returns - ------- - List - A list of PyTorch Lightning callbacks - - """ checkpoint_settings = { "dirpath": config.hardware.paths.checkpoints, "verbose": False, @@ -1065,6 +74,7 @@ def get_callbacks(config: DictConfig) -> list: # noqa: C901 } ckpt_frequency_save_dict = {} + for key, frequency_dict in config.diagnostics.checkpoint.items(): frequency = frequency_dict["save_frequency"] n_saved = frequency_dict["num_models_saved"] @@ -1073,81 +83,122 @@ def get_callbacks(config: DictConfig) -> list: # noqa: C901 frequency = timedelta(minutes=frequency_dict["save_frequency"]) else: target = key - ckpt_frequency_save_dict[target] = (config.hardware.files.checkpoint[key], frequency, n_saved) + ckpt_frequency_save_dict[target] = ( + config.hardware.files.checkpoint[key], + frequency, + n_saved, + ) - trainer_callbacks = [] + checkpoint_callbacks = [] if not config.diagnostics.profiler: - for save_key, (name, save_frequency, save_n_models) in ckpt_frequency_save_dict.items(): + for save_key, ( + name, + save_frequency, + save_n_models, + ) in ckpt_frequency_save_dict.items(): if save_frequency is not None: LOGGER.debug("Checkpoint callback at %s = %s ...", save_key, save_frequency) - trainer_callbacks.extend( + checkpoint_callbacks.append( # save_top_k: the save_top_k flag can either save the best or the last k checkpoints # depending on the monitor flag on ModelCheckpoint. 
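For reference, a hypothetical `diagnostics.checkpoint` section consistent with the loop above; the condition selecting the minutes-based trigger is elided between the hunks, so the `every_n_minutes` key and its mapping to Lightning's `train_time_interval` kwarg are assumptions here:

```python
# Hypothetical diagnostics.checkpoint section and the resulting trigger dict.
# The every_n_minutes -> train_time_interval mapping is an assumption, since
# that branch is elided from the hunk above; filenames are illustrative.
from datetime import timedelta

checkpoint_config = {
    "every_n_minutes": {"save_frequency": 30, "num_models_saved": 3},
    "every_n_epochs": {"save_frequency": 1, "num_models_saved": -1},
}

ckpt_frequency_save_dict = {}
for key, frequency_dict in checkpoint_config.items():
    frequency = frequency_dict["save_frequency"]
    n_saved = frequency_dict["num_models_saved"]
    if key == "every_n_minutes":
        target = "train_time_interval"  # assumed ModelCheckpoint kwarg
        frequency = timedelta(minutes=frequency)
    else:
        target = key
    ckpt_frequency_save_dict[target] = (f"ckpt-by-{key}", frequency, n_saved)

# Each entry is later unpacked as **{save_key: save_frequency} on AnemoiCheckpoint.
```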
                 # See https://lightning.ai/docs/pytorch/stable/common/checkpointing_intermediate.html for reference
-                [
-                    AnemoiCheckpoint(
-                        config=config,
-                        filename=name,
-                        save_last=True,
-                        **{save_key: save_frequency},
-                        # if save_top_k == k, last k models saved; if save_top_k == -1, all models are saved
-                        save_top_k=save_n_models,
-                        monitor="step",
-                        mode="max",
-                        **checkpoint_settings,
-                    ),
-                ],
+                AnemoiCheckpoint(
+                    config=config,
+                    filename=name,
+                    save_last=True,
+                    **{save_key: save_frequency},
+                    # if save_top_k == k, last k models saved; if save_top_k == -1, all models are saved
+                    save_top_k=save_n_models,
+                    monitor="step",
+                    mode="max",
+                    **checkpoint_settings,
+                ),
             )
-            else:
-                LOGGER.debug("Not setting up a checkpoint callback with %s", save_key)
+            LOGGER.debug("Not setting up a checkpoint callback with %s", save_key)
     else:
         # the tensorboard logger + pytorch profiler cause pickling errors when writing checkpoints
         LOGGER.warning("Profiling is enabled - will not write any training or inference model checkpoints!")
+    return checkpoint_callbacks
 
-    if any([config.diagnostics.log.wandb.enabled, config.diagnostics.log.mlflow.enabled]):
-        from pytorch_lightning.callbacks import LearningRateMonitor
 
-        trainer_callbacks.append(
-            LearningRateMonitor(
-                logging_interval="step",
-                log_momentum=False,
-            ),
-        )
+def _get_config_enabled_callbacks(config: DictConfig) -> list[Callback]:
+    """Get callbacks that are enabled in the config, according to CONFIG_ENABLED_CALLBACKS."""
+    callbacks = []
 
-    if config.diagnostics.eval.enabled:
-        trainer_callbacks.append(RolloutEval(config))
+    def check_key(config: dict, key: str | Iterable[str] | Callable[[DictConfig], bool]) -> bool:
+        """Check key in config."""
+        if isinstance(key, Callable):
+            return key(config)
+        if isinstance(key, str):
+            return nestedget(config, key, False)
+        if isinstance(key, Iterable):
+            return all(nestedget(config, k, False) for k in key)
+        return nestedget(config, key, False)
 
-    if config.diagnostics.plot.enabled:
-        trainer_callbacks.extend(
-            [
-                PlotLoss(config),
-                PlotSample(config),
-            ],
-        )
-        if (config.diagnostics.plot.parameters_histogram or config.diagnostics.plot.parameters_spectrum) is not None:
-            trainer_callbacks.extend([PlotAdditionalMetrics(config)])
-        if config.diagnostics.plot.get("longrollout") and config.diagnostics.plot.longrollout.enabled:
-            trainer_callbacks.extend([LongRolloutPlots(config)])
+    for enable_key, callback_class in CONFIG_ENABLED_CALLBACKS:
+        if check_key(config, enable_key):
+            callbacks.append(callback_class(config))
 
-    if config.training.swa.enabled:
-        from pytorch_lightning.callbacks.stochastic_weight_avg import StochasticWeightAveraging
+    return callbacks
 
-        trainer_callbacks.append(
-            StochasticWeightAveraging(
-                swa_lrs=config.training.swa.lr,
-                swa_epoch_start=min(
-                    int(0.75 * config.training.max_epochs),
-                    config.training.max_epochs - 1,
-                ),
-                annealing_epochs=max(int(0.25 * config.training.max_epochs), 1),
-                annealing_strategy="cos",
-                device=None,
-            ),
-        )
 
+def get_callbacks(config: DictConfig) -> list[Callback]:
+    """Set up callbacks for PyTorch Lightning trainer.
+
+    Set `config.diagnostics.callbacks` to a list of callback configurations
+    in hydra form.
+
+    E.g.:
+    ```
+    callbacks:
+      - _target_: anemoi.training.diagnostics.callbacks.RolloutEval
+        rollout: 1
+        every_n_batches: 12
+    ```
+
+    Callbacks listed under `config.diagnostics.plot.callbacks` take the same
+    form, but will only be added if `config.diagnostics.plot.enabled` is set to True.
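To make the dispatch in `check_key` concrete, a small sketch of the three accepted key forms; note that the order of the branches matters, since a plain string is itself iterable, so the callable and string checks must run before the iterable check. The config values here are hypothetical:

```python
# Sketch of the three key forms check_key accepts in
# _get_config_enabled_callbacks; config content is hypothetical.
from omegaconf import OmegaConf

from anemoi.training.diagnostics.callbacks import nestedget

config = OmegaConf.create({"training": {"swa": {"enabled": True}}})

string_key = "training.swa.enabled"                      # single dotted lookup
iterable_key = ("training.swa.enabled",)                 # all() of the lookups
callable_key = lambda cfg: nestedget(cfg, "training.swa.enabled", False)

assert nestedget(config, string_key, False)              # str branch
assert all(nestedget(config, k, False) for k in iterable_key)  # Iterable branch
assert callable_key(config)                              # Callable branch
```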
+ + A callback must take a `DictConfig` in its `__init__` method as the first argument, + which will be the complete configuration object. + + Some callbacks are added by default, depending on the configuration. + See CONFIG_ENABLED_CALLBACKS for more information. + + Parameters + ---------- + config : DictConfig + Job configuration + + Returns + ------- + List[Callback] + A list of PyTorch Lightning callbacks + + """ + trainer_callbacks: list[Callback] = [] + + # Get Checkpoint callback + trainer_callbacks.extend(_get_checkpoint_callback(config)) + + # Base callbacks + trainer_callbacks.extend( + instantiate(callback, config) for callback in config.diagnostics.get("callbacks", None) or [] + ) + + # Plotting callbacks + + trainer_callbacks.extend( + instantiate(callback, config) for callback in config.diagnostics.plot.get("callbacks", None) or [] + ) + + # Extend with config enabled callbacks + trainer_callbacks.extend(_get_config_enabled_callbacks(config)) + + # Parent UUID callback trainer_callbacks.append(ParentUUIDCallback(config)) - if config.diagnostics.plot.learned_features: - LOGGER.debug("Setting up a callback to plot the trainable graph node features ...") - trainer_callbacks.append(GraphTrainableFeaturesPlot(config)) return trainer_callbacks + + +__all__ = ["get_callbacks"] diff --git a/src/anemoi/training/diagnostics/callbacks/checkpoint.py b/src/anemoi/training/diagnostics/callbacks/checkpoint.py new file mode 100644 index 00000000..cb95f5a4 --- /dev/null +++ b/src/anemoi/training/diagnostics/callbacks/checkpoint.py @@ -0,0 +1,181 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +from __future__ import annotations + +import logging +import time +import uuid +from pathlib import Path +from typing import TYPE_CHECKING + +import torch +import torchinfo +from anemoi.utils.checkpoints import save_metadata +from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint +from pytorch_lightning.utilities import rank_zero_only + +if TYPE_CHECKING: + import pytorch_lightning as pl + from omegaconf import OmegaConf + +LOGGER = logging.getLogger(__name__) + + +class AnemoiCheckpoint(ModelCheckpoint): + """A checkpoint callback that saves the model after every validation epoch.""" + + def __init__(self, config: OmegaConf, **kwargs: dict) -> None: + """Initialise the AnemoiCheckpoint callback. + + Parameters + ---------- + config : OmegaConf + Config object + kwargs : dict + Additional keyword arguments for Pytorch ModelCheckpoint + + """ + super().__init__(**kwargs) + self.config = config + self.start = time.time() + self._model_metadata = None + self._tracker_metadata = None + self._tracker_name = None + + @staticmethod + def _torch_drop_down(trainer: pl.Trainer) -> torch.nn.Module: + # Get the model from the DataParallel wrapper, for single and multi-gpu cases + assert hasattr(trainer, "model"), "Trainer has no attribute 'model'! Is the Pytorch Lightning version correct?" 
+ return trainer.model.module.model if hasattr(trainer.model, "module") else trainer.model.model + + @rank_zero_only + def model_metadata(self, model: torch.nn.Module) -> dict: + if self._model_metadata is not None: + return self._model_metadata + + self._model_metadata = { + "model": model.__class__.__name__, + "trainable_parameters": sum(p.numel() for p in model.parameters() if p.requires_grad), + "total_parameters": sum(p.numel() for p in model.parameters()), + "summary": repr( + torchinfo.summary( + model, + depth=50, + verbose=0, + row_settings=["var_names"], + ), + ), + } + + return self._model_metadata + + def tracker_metadata(self, trainer: pl.Trainer) -> dict: + if self._tracker_metadata is not None: + return {self._tracker_name: self._tracker_metadata} + + if self.config.diagnostics.log.wandb.enabled: + self._tracker_name = "wand" + import wandb + + run = wandb.run + if run is not None: + self._tracker_metadata = { + "id": run.id, + "name": run.name, + "url": run.url, + "project": run.project, + } + return {self._tracker_name: self._tracker_metadata} + + if self.config.diagnostics.log.mlflow.enabled: + self._tracker_name = "mlflow" + + from anemoi.training.diagnostics.mlflow.logger import AnemoiMLflowLogger + + mlflow_logger = next(logger for logger in trainer.loggers if isinstance(logger, AnemoiMLflowLogger)) + run_id = mlflow_logger.run_id + run = mlflow_logger._mlflow_client.get_run(run_id) + + if run is not None: + self._tracker_metadata = { + "id": run.info.run_id, + "name": run.info.run_name, + "url": run.info.artifact_uri, + "project": run.info.experiment_id, + } + return {self._tracker_name: self._tracker_metadata} + + return {} + + def _remove_checkpoint(self, trainer: pl.Trainer, filepath: str) -> None: + """Calls the strategy to remove the checkpoint file.""" + super()._remove_checkpoint(trainer, filepath) + trainer.strategy.remove_checkpoint(self._get_inference_checkpoint_filepath(filepath)) + + def _get_inference_checkpoint_filepath(self, filepath: str) -> str: + """Defines the filepath for the inference checkpoint.""" + return Path(filepath).parent / Path("inference-" + str(Path(filepath).name)) + + def _save_checkpoint(self, trainer: pl.Trainer, lightning_checkpoint_filepath: str) -> None: + if trainer.is_global_zero: + model = self._torch_drop_down(trainer) + + # We want a different uuid each time we save the model + # so we can tell them apart in the catalogue (i.e. 
different epochs) + checkpoint_uuid = str(uuid.uuid4()) + trainer.lightning_module._hparams["metadata"]["uuid"] = checkpoint_uuid + + trainer.lightning_module._hparams["metadata"]["model"] = self.model_metadata(model) + trainer.lightning_module._hparams["metadata"]["tracker"] = self.tracker_metadata(trainer) + + trainer.lightning_module._hparams["metadata"]["training"] = { + "current_epoch": trainer.current_epoch, + "global_step": trainer.global_step, + "elapsed_time": time.time() - self.start, + } + + Path(lightning_checkpoint_filepath).parent.mkdir(parents=True, exist_ok=True) + + save_config = model.config + model.config = None + + tmp_metadata = model.metadata + model.metadata = None + + metadata = dict(**tmp_metadata) + + inference_checkpoint_filepath = self._get_inference_checkpoint_filepath(lightning_checkpoint_filepath) + + torch.save(model, inference_checkpoint_filepath) + + save_metadata(inference_checkpoint_filepath, metadata) + + model.config = save_config + model.metadata = tmp_metadata + + self._last_global_step_saved = trainer.global_step + + trainer.strategy.barrier() + + # saving checkpoint used for pytorch-lightning based training + trainer.save_checkpoint(lightning_checkpoint_filepath, self.save_weights_only) + + self._last_global_step_saved = trainer.global_step + self._last_checkpoint_saved = lightning_checkpoint_filepath + + if trainer.is_global_zero: + from weakref import proxy + + # save metadata for the training checkpoint in the same format as inference + save_metadata(lightning_checkpoint_filepath, metadata) + + # notify loggers + for logger in trainer.loggers: + logger.after_save_checkpoint(proxy(self)) diff --git a/src/anemoi/training/diagnostics/callbacks/evaluation.py b/src/anemoi/training/diagnostics/callbacks/evaluation.py new file mode 100644 index 00000000..fc812121 --- /dev/null +++ b/src/anemoi/training/diagnostics/callbacks/evaluation.py @@ -0,0 +1,126 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +from __future__ import annotations + +import logging +from contextlib import nullcontext +from typing import TYPE_CHECKING + +import torch +from pytorch_lightning.callbacks import Callback +from pytorch_lightning.utilities import rank_zero_only + +if TYPE_CHECKING: + import pytorch_lightning as pl + from omegaconf import OmegaConf + +LOGGER = logging.getLogger(__name__) + + +class RolloutEval(Callback): + """Evaluates the model performance over a (longer) rollout window.""" + + def __init__(self, config: OmegaConf, rollout: int, every_n_batches: int) -> None: + """Initialize RolloutEval callback. 
+
+        Parameters
+        ----------
+        config : OmegaConf
+            Config object
+        rollout : int
+            Rollout length for evaluation
+        every_n_batches : int
+            Frequency of rollout evaluation, runs every `n` validation batches
+
+        """
+        super().__init__()
+        self.config = config
+
+        LOGGER.debug(
+            "Setting up RolloutEval callback with rollout = %d, every_n_batches = %d ...",
+            rollout,
+            every_n_batches,
+        )
+        self.rollout = rollout
+        self.every_n_batches = every_n_batches
+
+    def _eval(
+        self,
+        pl_module: pl.LightningModule,
+        batch: torch.Tensor,
+    ) -> None:
+        loss = torch.zeros(1, dtype=batch.dtype, device=pl_module.device, requires_grad=False)
+        metrics = {}
+
+        assert batch.shape[1] >= self.rollout + pl_module.multi_step, (
+            "Batch length not sufficient for requested validation rollout length! "
+            f"Set `dataloader.validation_rollout` to at least {self.rollout}"
+        )
+
+        with torch.no_grad():
+            for loss_next, metrics_next, _ in pl_module.rollout_step(
+                batch,
+                rollout=self.rollout,
+                validation_mode=True,
+                training_mode=True,
+            ):
+                loss += loss_next
+                metrics.update(metrics_next)
+
+            # scale loss
+            loss *= 1.0 / self.rollout
+            self._log(pl_module, loss, metrics, batch.shape[0])
+
+    def _log(self, pl_module: pl.LightningModule, loss: torch.Tensor, metrics: dict, bs: int) -> None:
+        pl_module.log(
+            f"val_r{self.rollout}_{getattr(pl_module.loss, 'name', pl_module.loss.__class__.__name__.lower())}",
+            loss,
+            on_epoch=True,
+            on_step=True,
+            prog_bar=False,
+            logger=pl_module.logger_enabled,
+            batch_size=bs,
+            sync_dist=False,
+            rank_zero_only=True,
+        )
+        for mname, mvalue in metrics.items():
+            pl_module.log(
+                f"val_r{self.rollout}_" + mname,
+                mvalue,
+                on_epoch=True,
+                on_step=False,
+                prog_bar=False,
+                logger=pl_module.logger_enabled,
+                batch_size=bs,
+                sync_dist=False,
+                rank_zero_only=True,
+            )
+
+    @rank_zero_only
+    def on_validation_batch_end(
+        self,
+        trainer: pl.Trainer,
+        pl_module: pl.LightningModule,
+        outputs: list,
+        batch: torch.Tensor,
+        batch_idx: int,
+    ) -> None:
+        del outputs  # outputs are not used
+        if batch_idx % self.every_n_batches == 0:
+            precision_mapping = {
+                "16-mixed": torch.float16,
+                "bf16-mixed": torch.bfloat16,
+            }
+            prec = trainer.precision
+            dtype = precision_mapping.get(prec)
+            context = torch.autocast(device_type=batch.device.type, dtype=dtype) if dtype is not None else nullcontext()
+
+            with context:
+                self._eval(pl_module, batch)
diff --git a/src/anemoi/training/diagnostics/callbacks/optimiser.py b/src/anemoi/training/diagnostics/callbacks/optimiser.py
new file mode 100644
index 00000000..bff82fcb
--- /dev/null
+++ b/src/anemoi/training/diagnostics/callbacks/optimiser.py
@@ -0,0 +1,77 @@
+# (C) Copyright 2024 Anemoi contributors.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
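With the refactor, `RolloutEval` is opt-in via `config.diagnostics.callbacks`. A minimal sketch of what `get_callbacks` does with such an entry, assuming hydra-core is installed and using the module path implied by the new file layout:

```python
# Sketch: how get_callbacks turns a diagnostics.callbacks entry into a
# RolloutEval instance. The _target_ path follows the new file layout and
# is an assumption here; the empty config stands in for the full job config.
from hydra.utils import instantiate
from omegaconf import OmegaConf

config = OmegaConf.create({})  # stands in for the full job config

callback_cfg = OmegaConf.create(
    {
        "_target_": "anemoi.training.diagnostics.callbacks.evaluation.RolloutEval",
        "rollout": 1,
        "every_n_batches": 12,
    },
)

# get_callbacks passes the full job config as the first positional argument,
# so this is equivalent to RolloutEval(config, rollout=1, every_n_batches=12).
rollout_eval = instantiate(callback_cfg, config)
```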
+
+from __future__ import annotations
+
+import logging
+from typing import TYPE_CHECKING
+
+from pytorch_lightning.callbacks import LearningRateMonitor as pl_LearningRateMonitor
+from pytorch_lightning.callbacks.stochastic_weight_avg import StochasticWeightAveraging as pl_StochasticWeightAveraging
+
+LOGGER = logging.getLogger(__name__)
+
+if TYPE_CHECKING:
+    from omegaconf import DictConfig
+
+
+class LearningRateMonitor(pl_LearningRateMonitor):
+    """Provide LearningRateMonitor from pytorch_lightning as a callback."""
+
+    def __init__(
+        self,
+        config: DictConfig,
+        logging_interval: str = "step",
+        log_momentum: bool = False,
+    ) -> None:
+        super().__init__(logging_interval=logging_interval, log_momentum=log_momentum)
+        self.config = config
+
+
+class StochasticWeightAveraging(pl_StochasticWeightAveraging):
+    """Provide StochasticWeightAveraging from pytorch_lightning as a callback."""
+
+    def __init__(
+        self,
+        config: DictConfig,
+        swa_lrs: float | None = None,
+        swa_epoch_start: int | None = None,
+        annealing_epochs: int | None = None,
+        annealing_strategy: str | None = None,
+        device: str | None = None,
+        **kwargs,
+    ) -> None:
+        """Stochastic Weight Averaging Callback.
+
+        Parameters
+        ----------
+        config : DictConfig
+            Full configuration object
+        swa_lrs : float, optional
+            Stochastic Weight Averaging Learning Rate, by default config.training.swa.lr
+        swa_epoch_start : int, optional
+            Epoch start, by default 0.75 * config.training.max_epochs
+        annealing_epochs : int, optional
+            Number of annealing epochs, by default 0.25 * config.training.max_epochs
+        annealing_strategy : str, optional
+            Annealing Strategy, by default 'cos'
+        device : str, optional
+            Device to use, by default None
+        """
+        kwargs["swa_lrs"] = swa_lrs or config.training.swa.lr
+        kwargs["swa_epoch_start"] = swa_epoch_start or min(
+            int(0.75 * config.training.max_epochs),
+            config.training.max_epochs - 1,
+        )
+        kwargs["annealing_epochs"] = annealing_epochs or max(int(0.25 * config.training.max_epochs), 1)
+        kwargs["annealing_strategy"] = annealing_strategy or "cos"
+        kwargs["device"] = device
+
+        super().__init__(**kwargs)
+        self.config = config
diff --git a/src/anemoi/training/diagnostics/callbacks/plot.py b/src/anemoi/training/diagnostics/callbacks/plot.py
new file mode 100644
index 00000000..869a69fb
--- /dev/null
+++ b/src/anemoi/training/diagnostics/callbacks/plot.py
@@ -0,0 +1,961 @@
+# (C) Copyright 2024 Anemoi contributors.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
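The wrapper derives its schedule from `config.training` when no explicit values are given. The arithmetic, for a hypothetical 100-epoch run:

```python
# Derived SWA schedule for a hypothetical config.training.max_epochs = 100:
# averaging starts at 75% of training and anneals over the last quarter.
max_epochs = 100
swa_epoch_start = min(int(0.75 * max_epochs), max_epochs - 1)  # -> 75
annealing_epochs = max(int(0.25 * max_epochs), 1)              # -> 25
assert (swa_epoch_start, annealing_epochs) == (75, 25)
```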
+
+# ruff: noqa: ANN001
+
+from __future__ import annotations
+
+import copy
+import logging
+import sys
+import time
+import traceback
+from abc import ABC
+from abc import abstractmethod
+from concurrent.futures import ThreadPoolExecutor
+from contextlib import nullcontext
+from functools import cached_property
+from pathlib import Path
+from typing import TYPE_CHECKING
+from typing import Any
+from typing import Callable
+
+import matplotlib.patches as mpatches
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+from pytorch_lightning.callbacks import Callback
+from pytorch_lightning.utilities import rank_zero_only
+
+from anemoi.training.diagnostics.plots import init_plot_settings
+from anemoi.training.diagnostics.plots import plot_graph_edge_features
+from anemoi.training.diagnostics.plots import plot_graph_node_features
+from anemoi.training.diagnostics.plots import plot_histogram
+from anemoi.training.diagnostics.plots import plot_loss
+from anemoi.training.diagnostics.plots import plot_power_spectrum
+from anemoi.training.diagnostics.plots import plot_predicted_multilevel_flat_sample
+from anemoi.training.losses.weightedloss import BaseWeightedLoss
+
+if TYPE_CHECKING:
+    import pytorch_lightning as pl
+    from omegaconf import OmegaConf
+
+LOGGER = logging.getLogger(__name__)
+
+
+class ParallelExecutor(ThreadPoolExecutor):
+    """Wraps parallel execution and provides accurate information about errors.
+
+    Extends ThreadPoolExecutor to preserve the original traceback and line number.
+
+    Reference: https://stackoverflow.com/questions/19309514/getting-original-line-
+    number-for-exception-in-concurrent-futures/24457608#24457608
+    """
+
+    def submit(self, fn: Any, *args, **kwargs) -> Callable:
+        """Submits the wrapped function instead of `fn`."""
+        return super().submit(self._function_wrapper, fn, *args, **kwargs)
+
+    def _function_wrapper(self, fn: Any, *args: list, **kwargs: dict) -> Callable:
+        """Wraps `fn` in order to preserve the traceback of any raised exception."""
+        try:
+            return fn(*args, **kwargs)
+        except Exception as exc:
+            raise sys.exc_info()[0](traceback.format_exc()) from exc
+
+
+class BasePlotCallback(Callback, ABC):
+    """Factory for creating a callback that plots data to Experiment Logging."""
+
+    def __init__(self, config: OmegaConf) -> None:
+        """Initialise the BasePlotCallback abstract base class.
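A toy demonstration of the traceback-preserving submit wrapper, assuming `ParallelExecutor` is importable from this module's path; `boom` is a hypothetical failing task:

```python
# The worker's formatted traceback ends up in the re-raised exception's
# message, so the failing line inside boom() is visible at the call site.
from anemoi.training.diagnostics.callbacks.plot import ParallelExecutor


def boom() -> None:
    raise ValueError("failure inside a worker thread")


executor = ParallelExecutor(max_workers=1)
future = executor.submit(boom)
try:
    future.result()  # re-raises ValueError, message contains the traceback
except ValueError as exc:
    print(exc)  # includes the original line number inside boom()
finally:
    executor.shutdown(wait=True)
```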
+ + Parameters + ---------- + config : OmegaConf + Config object + + """ + super().__init__() + self.config = config + self.save_basedir = config.hardware.paths.plots + + self.post_processors = None + self.pre_processors = None + self.latlons = None + init_plot_settings() + + self.plot = self._plot + self._executor = None + + if self.config.diagnostics.plot.asynchronous: + self._executor = ParallelExecutor(max_workers=1) + self._error: BaseException | None = None + self.plot = self._async_plot + + @rank_zero_only + def _output_figure( + self, + logger: pl.loggers.base.LightningLoggerBase, + fig: plt.Figure, + epoch: int, + tag: str = "gnn", + exp_log_tag: str = "val_pred_sample", + ) -> None: + """Figure output: save to file and/or display in notebook.""" + if self.save_basedir is not None: + save_path = Path( + self.save_basedir, + "plots", + f"{tag}_epoch{epoch:03d}.png", + ) + + save_path.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(save_path, dpi=100, bbox_inches="tight") + if self.config.diagnostics.log.wandb.enabled: + import wandb + + logger.experiment.log({exp_log_tag: wandb.Image(fig)}) + + if self.config.diagnostics.log.mlflow.enabled: + run_id = logger.run_id + logger.experiment.log_artifact(run_id, str(save_path)) + + plt.close(fig) # cleanup + + def teardown(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str) -> None: + """Method is called to close the threads.""" + del trainer, pl_module, stage # unused + if self._executor is not None: + self._executor.shutdown(wait=True) + + def apply_output_mask(self, pl_module: pl.LightningModule, data: torch.Tensor) -> torch.Tensor: + if hasattr(pl_module, "output_mask") and pl_module.output_mask is not None: + # Fill with NaNs values where the mask is False + data[:, :, ~pl_module.output_mask, :] = np.nan + return data + + @abstractmethod + @rank_zero_only + def _plot( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + *args, + **kwargs, + ) -> None: + """Plotting function to be implemented by subclasses.""" + + @rank_zero_only + def _async_plot( + self, + trainer: pl.Trainer, + *args: list, + **kwargs: dict, + ) -> None: + """To execute the plot function but ensuring we catch any errors.""" + future = self._executor.submit( + self._plot, + trainer, + *args, + **kwargs, + ) + # otherwise the error won't be thrown till the validation epoch is finished + try: + future.result() + except Exception: + LOGGER.exception("Critical error occurred in asynchronous plots.") + sys.exit(1) + + +class BasePerBatchPlotCallback(BasePlotCallback): + """Base Callback for plotting at the end of each batch.""" + + def __init__(self, config: OmegaConf, every_n_batches: int | None = None): + """Initialise the BasePerBatchPlotCallback. 
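The mask convention used by `apply_output_mask` above can be seen in isolation; shapes here are illustrative:

```python
# Toy illustration of apply_output_mask: grid points outside the output
# mask are filled with NaN so downstream plots simply leave them blank.
import numpy as np
import torch

data = torch.zeros(4, 2, 10, 7)      # (time, ensemble, latlon, variable)
output_mask = torch.rand(10) > 0.5   # True where the model produces output
data[:, :, ~output_mask, :] = np.nan
```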
+ + Parameters + ---------- + config : OmegaConf + Config object + every_n_batches : int, optional + Batch Frequency to plot at, by default None + If not given, uses default from config at `diagnostics.plot.frequency.batch` + + """ + super().__init__(config) + self.every_n_batches = every_n_batches or self.config.diagnostics.plot.frequency.batch + + @abstractmethod + @rank_zero_only + def _plot( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + outputs: list[torch.Tensor], + batch: torch.Tensor, + batch_idx: int, + epoch: int, + **kwargs, + ) -> None: + """Plotting function to be implemented by subclasses.""" + + @rank_zero_only + def on_validation_batch_end( + self, + trainer, + pl_module, + output, + batch: torch.Tensor, + batch_idx: int, + **kwargs, + ) -> None: + if batch_idx % self.every_n_batches == 0: + self.plot( + trainer, + pl_module, + output, + batch, + batch_idx, + epoch=trainer.current_epoch, + **kwargs, + ) + + +class BasePerEpochPlotCallback(BasePlotCallback): + """Base Callback for plotting at the end of each epoch.""" + + def __init__(self, config: OmegaConf, every_n_epochs: int | None = None): + """Initialise the BasePerEpochPlotCallback. + + Parameters + ---------- + config : OmegaConf + Config object + every_n_epochs : int, optional + Epoch frequency to plot at, by default None + If not given, uses default from config at `diagnostics.plot.frequency.epoch` + """ + super().__init__(config) + self.every_n_epochs = every_n_epochs or self.config.diagnostics.plot.frequency.epoch + + @rank_zero_only + def on_validation_epoch_end( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + **kwargs, + ) -> None: + if trainer.current_epoch % self.every_n_epochs == 0: + self.plot(trainer, pl_module, epoch=trainer.current_epoch, **kwargs) + + +class LongRolloutPlots(BasePlotCallback): + """Evaluates the model performance over a (longer) rollout window.""" + + def __init__( + self, + config: OmegaConf, + rollout: list[int], + sample_idx: int, + parameters: list[str], + accumulation_levels_plot: list[float] | None = None, + cmap_accumulation: list[str] | None = None, + per_sample: int = 6, + every_n_epochs: int = 1, + ) -> None: + """Initialise LongRolloutPlots callback. 
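With the new base classes, a plot callback only implements `_plot`; the batch-frequency gating and optional asynchronous execution come from the base class. A hypothetical minimal subclass:

```python
# Hypothetical subclass: BasePerBatchPlotCallback handles the
# every_n_batches gating, so only _plot needs implementing. Instantiation
# still requires a full job config, e.g. PlotNothing(config, every_n_batches=10).
from pytorch_lightning.utilities import rank_zero_only


class PlotNothing(BasePerBatchPlotCallback):
    @rank_zero_only
    def _plot(self, trainer, pl_module, outputs, batch, batch_idx, epoch, **kwargs) -> None:
        LOGGER.info("Would plot batch %d of epoch %d here", batch_idx, epoch)
```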
+
+        Parameters
+        ----------
+        config : OmegaConf
+            Config object
+        rollout : list[int]
+            Rollout steps to plot at
+        sample_idx : int
+            Sample to plot
+        parameters : list[str]
+            Parameters to plot
+        accumulation_levels_plot : list[float] | None
+            Accumulation levels to plot, by default None
+        cmap_accumulation : list[str] | None
+            Colors of the accumulation levels, by default None
+        per_sample : int, optional
+            Number of plots per sample, by default 6
+        every_n_epochs : int, optional
+            Epoch frequency to plot at, by default 1
+        """
+        super().__init__(config)
+
+        self.every_n_epochs = every_n_epochs
+
+        LOGGER.debug(
+            "Setting up callback for plots with long rollout: rollout = %s, frequency = every %d epoch ...",
+            rollout,
+            every_n_epochs,
+        )
+        self.rollout = rollout
+        self.sample_idx = sample_idx
+        self.accumulation_levels_plot = accumulation_levels_plot
+        self.cmap_accumulation = cmap_accumulation
+        self.per_sample = per_sample
+        self.parameters = parameters
+
+    @rank_zero_only
+    def _plot(
+        self,
+        trainer,
+        pl_module: pl.LightningModule,
+        output: list[torch.Tensor],
+        batch: torch.Tensor,
+        batch_idx,
+        epoch,
+    ) -> None:
+        _ = output
+
+        start_time = time.time()
+
+        logger = trainer.logger
+
+        # Build dictionary of indices and parameters to be plotted
+        plot_parameters_dict = {
+            pl_module.data_indices.model.output.name_to_index[name]: (
+                name,
+                name not in self.config.data.get("diagnostic", []),
+            )
+            for name in self.parameters
+        }
+
+        if self.post_processors is None:
+            # Copy to be used across all the training cycle
+            self.post_processors = copy.deepcopy(pl_module.model.post_processors).cpu()
+        if self.latlons is None:
+            self.latlons = np.rad2deg(pl_module.latlons_data.clone().cpu().numpy())
+        local_rank = pl_module.local_rank
+
+        assert batch.shape[1] >= max(self.rollout) + pl_module.multi_step, (
+            "Batch length not sufficient for requested validation rollout length! 
" + f"Set `dataloader.validation_rollout` to at least {max(self.rollout)}" + ) + + # prepare input tensor for plotting + input_batch = pl_module.model.pre_processors(batch, in_place=False) + input_tensor_0 = input_batch[ + self.sample_idx, + pl_module.multi_step - 1, + ..., + pl_module.data_indices.internal_data.output.full, + ].cpu() + data_0 = self.post_processors(input_tensor_0).numpy() + + # start rollout + with torch.no_grad(): + for rollout_step, (_, _, y_pred) in enumerate( + pl_module.rollout_step( + batch, + rollout=max(self.rollout), + validation_mode=False, + training_mode=False, + ), + ): + + if (rollout_step + 1) in self.rollout: + # prepare true output tensor for plotting + input_tensor_rollout_step = input_batch[ + self.sample_idx, + pl_module.multi_step + rollout_step, # (pl_module.multi_step - 1) + (rollout_step + 1) + ..., + pl_module.data_indices.internal_data.output.full, + ].cpu() + data_rollout_step = self.post_processors(input_tensor_rollout_step).numpy() + + # prepare predicted output tensor for plotting + output_tensor = self.post_processors( + y_pred[self.sample_idx : self.sample_idx + 1, ...].cpu(), + ).numpy() + + fig = plot_predicted_multilevel_flat_sample( + plot_parameters_dict, + self.per_sample, + self.latlons, + self.accumulation_levels_plot, + self.cmap_accumulation, + data_0.squeeze(), + data_rollout_step.squeeze(), + output_tensor[0, 0, :, :], # rolloutstep, first member + ) + + self._output_figure( + logger, + fig, + epoch=epoch, + tag=f"gnn_pred_val_sample_rstep{rollout_step + 1:03d}_batch{batch_idx:04d}_rank0", + exp_log_tag=f"val_pred_sample_rstep{rollout_step + 1:03d}_rank{local_rank:01d}", + ) + LOGGER.info( + "Time taken to plot samples after longer rollout: %s seconds", + int(time.time() - start_time), + ) + + @rank_zero_only + def on_validation_batch_end( + self, + trainer, + pl_module, + output, + batch: torch.Tensor, + batch_idx: int, + ) -> None: + if (batch_idx) == 0 and (trainer.current_epoch + 1) % self.every_n_epochs == 0: + precision_mapping = { + "16-mixed": torch.float16, + "bf16-mixed": torch.bfloat16, + } + prec = trainer.precision + dtype = precision_mapping.get(prec) + context = torch.autocast(device_type=batch.device.type, dtype=dtype) if dtype is not None else nullcontext() + + if self.config.diagnostics.plot.asynchronous: + LOGGER.warning("Asynchronous plotting not supported for long rollout plots.") + + with context: + # Issue with running asyncronously, so call the plot function directly + self._plot(trainer, pl_module, output, batch, batch_idx, trainer.current_epoch) + + +class GraphTrainableFeaturesPlot(BasePerEpochPlotCallback): + """Visualize the node & edge trainable features defined.""" + + def __init__(self, config: OmegaConf, every_n_epochs: int | None = None) -> None: + """Initialise the GraphTrainableFeaturesPlot callback. 
+ + Parameters + ---------- + config : OmegaConf + Config object + every_n_epochs: int | None, optional + Override for frequency to plot at, by default None + """ + super().__init__(config, every_n_epochs=every_n_epochs) + + @rank_zero_only + def _plot( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + epoch: int, + ) -> None: + _ = epoch + model = pl_module.model.module.model if hasattr(pl_module.model, "module") else pl_module.model.model + + fig = plot_graph_node_features(model) + + self._output_figure( + trainer.logger, + fig, + epoch=trainer.current_epoch, + tag="node_trainable_params", + exp_log_tag="node_trainable_params", + ) + + fig = plot_graph_edge_features(model) + + self._output_figure( + trainer.logger, + fig, + epoch=trainer.current_epoch, + tag="edge_trainable_params", + exp_log_tag="edge_trainable_params", + ) + + +class PlotLoss(BasePerBatchPlotCallback): + """Plots the unsqueezed loss over rollouts.""" + + def __init__( + self, + config: OmegaConf, + parameter_groups: dict[dict[str, list[str]]], + every_n_batches: int | None = None, + ) -> None: + """Initialise the PlotLoss callback. + + Parameters + ---------- + config : OmegaConf + Object with configuration settings + parameter_groups : dict + Dictionary with parameter groups with parameter names as keys + every_n_batches : int, optional + Override for batch frequency, by default None + + """ + super().__init__(config, every_n_batches=every_n_batches) + self.parameter_names = None + self.parameter_groups = parameter_groups + if self.parameter_groups is None: + self.parameter_groups = {} + + @cached_property + def sort_and_color_by_parameter_group( + self, + ) -> tuple[np.ndarray, np.ndarray, dict, list]: + """Sort parameters by group and prepare colors.""" + + def automatically_determine_group(name: str) -> str: + # first prefix of parameter name is group name + parts = name.split("_") + return parts[0] + + # group parameters by their determined group name for > 15 parameters + if len(self.parameter_names) <= 15: + # for <= 15 parameters, keep the full name of parameters + parameters_to_groups = np.array(self.parameter_names) + sort_by_parameter_group = np.arange(len(self.parameter_names), dtype=int) + else: + parameters_to_groups = np.array( + [ + next( + ( + group_name + for group_name, group_parameters in self.parameter_groups.items() + if name in group_parameters + ), + automatically_determine_group(name), + ) + for name in self.parameter_names + ], + ) + + unique_group_list, group_inverse, group_counts = np.unique( + parameters_to_groups, + return_inverse=True, + return_counts=True, + ) + + # join parameter groups that appear only once and are not given in config-file + unique_group_list = np.array( + [ + (unique_group_list[tn] if count > 1 or unique_group_list[tn] in self.parameter_groups else "other") + for tn, count in enumerate(group_counts) + ], + ) + parameters_to_groups = unique_group_list[group_inverse] + unique_group_list, group_inverse = np.unique(parameters_to_groups, return_inverse=True) + + # sort parameters by groups + sort_by_parameter_group = np.argsort(group_inverse, kind="stable") + + # apply new order to parameters + sorted_parameter_names = np.array(self.parameter_names)[sort_by_parameter_group] + parameters_to_groups = parameters_to_groups[sort_by_parameter_group] + unique_group_list, group_inverse, group_counts = np.unique( + parameters_to_groups, + return_inverse=True, + return_counts=True, + ) + + # get a color per group and project to parameter list + cmap = "tab10" if 
len(unique_group_list) <= 10 else "tab20"
+        if len(unique_group_list) > 20:
+            LOGGER.warning("More than 20 groups detected, but colormap has only 20 colors.")
+        # if all groups have count 1 use black color
+        bar_color_per_group = (
+            np.tile("k", len(group_counts))
+            if not np.any(group_counts - 1)
+            else plt.get_cmap(cmap)(np.linspace(0, 1, len(unique_group_list)))
+        )
+
+        # set x-ticks
+        x_tick_positions = np.cumsum(group_counts) - group_counts / 2 - 0.5
+        xticks = dict(zip(unique_group_list, x_tick_positions))
+
+        legend_patches = []
+        for group_idx, group in enumerate(unique_group_list):
+            text_label = f"{group}: "
+            string_length = len(text_label)
+            for ii in np.where(group_inverse == group_idx)[0]:
+                text_label += sorted_parameter_names[ii] + ", "
+                string_length += len(sorted_parameter_names[ii]) + 2
+                if string_length > 50:
+                    # linebreak after 50 characters
+                    text_label += "\n"
+                    string_length = 0
+            legend_patches.append(mpatches.Patch(color=bar_color_per_group[group_idx], label=text_label[:-2]))
+
+        return (
+            sort_by_parameter_group,
+            bar_color_per_group[group_inverse],
+            xticks,
+            legend_patches,
+        )
+
+    @rank_zero_only
+    def _plot(
+        self,
+        trainer: pl.Trainer,
+        pl_module: pl.LightningModule,
+        outputs: list[torch.Tensor],
+        batch: torch.Tensor,
+        batch_idx: int,
+        epoch: int,
+    ) -> None:
+        logger = trainer.logger
+        _ = batch_idx
+
+        parameter_names = list(pl_module.data_indices.internal_model.output.name_to_index.keys())
+        parameter_positions = list(pl_module.data_indices.internal_model.output.name_to_index.values())
+        # reorder parameter_names by position
+        self.parameter_names = [parameter_names[i] for i in np.argsort(parameter_positions)]
+        if not isinstance(pl_module.loss, BaseWeightedLoss):
+            LOGGER.warning("Loss function must be a subclass of BaseWeightedLoss, or provide `squash`.")
+
+        batch = pl_module.model.pre_processors(batch, in_place=False)
+        for rollout_step in range(pl_module.rollout):
+            y_hat = outputs[1][rollout_step]
+            y_true = batch[
+                :,
+                pl_module.multi_step + rollout_step,
+                ...,
+                pl_module.data_indices.internal_data.output.full,
+            ]
+            loss = pl_module.loss(y_hat, y_true, squash=False).cpu().numpy()
+
+            sort_by_parameter_group, colors, xticks, legend_patches = self.sort_and_color_by_parameter_group
+            fig = plot_loss(loss[sort_by_parameter_group], colors, xticks, legend_patches)
+
+            self._output_figure(
+                logger,
+                fig,
+                epoch=epoch,
+                tag=f"loss_rstep_rstep{rollout_step:02d}_rank{pl_module.local_rank:01d}",
+                exp_log_tag=f"loss_sample_rstep{rollout_step:02d}_rank{pl_module.local_rank:01d}",
+            )
+
+
+class PlotSample(BasePerBatchPlotCallback):
+    """Plots a post-processed sample: input, target and prediction."""
+
+    def __init__(
+        self,
+        config: OmegaConf,
+        sample_idx: int,
+        parameters: list[str],
+        accumulation_levels_plot: list[float],
+        cmap_accumulation: list[str],
+        precip_and_related_fields: list[str] | None = None,
+        per_sample: int = 6,
+        every_n_batches: int | None = None,
+    ) -> None:
+        """Initialise the PlotSample callback.
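The grouping heuristic in `sort_and_color_by_parameter_group` is easiest to see on toy names (illustrative, not from the config):

```python
# Illustrative parameter names: the first underscore-separated token is the
# automatic group. With more than 15 parameters, singleton groups not listed
# in diagnostics.plot.parameter_groups collapse into "other".
names = ["t_850", "t_500", "z_500", "q_700", "10u", "10v", "tp"]
groups = [name.split("_")[0] for name in names]
assert groups == ["t", "t", "z", "q", "10u", "10v", "tp"]
```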
+
+        Parameters
+        ----------
+        config : OmegaConf
+            Config object
+        sample_idx : int
+            Sample to plot
+        parameters : list[str]
+            Parameters to plot
+        accumulation_levels_plot : list[float]
+            Accumulation levels to plot
+        cmap_accumulation : list[str]
+            Colors of the accumulation levels
+        precip_and_related_fields : list[str] | None, optional
+            Precip variable names, by default None
+        per_sample : int, optional
+            Number of plots per sample, by default 6
+        every_n_batches : int, optional
+            Batch frequency to plot at, by default None
+        """
+        super().__init__(config, every_n_batches=every_n_batches)
+        self.sample_idx = sample_idx
+        self.parameters = parameters
+
+        self.precip_and_related_fields = precip_and_related_fields
+        self.accumulation_levels_plot = accumulation_levels_plot
+        self.cmap_accumulation = cmap_accumulation
+        self.per_sample = per_sample
+
+        LOGGER.info(
+            "Using defined accumulation colormap for fields: %s",
+            self.precip_and_related_fields,
+        )
+
+    @rank_zero_only
+    def _plot(
+        self,
+        trainer: pl.Trainer,
+        pl_module: pl.LightningModule,
+        outputs: list[torch.Tensor],
+        batch: torch.Tensor,
+        batch_idx: int,
+        epoch: int,
+    ) -> None:
+        logger = trainer.logger
+
+        # Build dictionary of indices and parameters to be plotted
+        diagnostics = [] if self.config.data.diagnostic is None else self.config.data.diagnostic
+        plot_parameters_dict = {
+            pl_module.data_indices.model.output.name_to_index[name]: (
+                name,
+                name not in diagnostics,
+            )
+            for name in self.parameters
+        }
+
+        # When running in async mode, these tensors may have been moved to the CPU in the
+        # last epoch (the denormalising would then fail, as 'input_tensor' would be on CUDA
+        # while the internal tensors would be on the CPU). The lines below address this.
+        if self.post_processors is None:
+            # Copy to be used across all the training cycle
+            self.post_processors = copy.deepcopy(pl_module.model.post_processors).cpu()
+        if self.latlons is None:
+            self.latlons = np.rad2deg(pl_module.latlons_data.clone().cpu().numpy())
+        local_rank = pl_module.local_rank
+
+        batch = pl_module.model.pre_processors(batch, in_place=False)
+        input_tensor = batch[
+            self.sample_idx,
+            pl_module.multi_step - 1 : pl_module.multi_step + pl_module.rollout + 1,
+            ...,
+            pl_module.data_indices.internal_data.output.full,
+        ].cpu()
+        data = self.post_processors(input_tensor)
+
+        output_tensor = self.post_processors(
+            torch.cat(tuple(x[self.sample_idx : self.sample_idx + 1, ...].cpu() for x in outputs[1])),
+            in_place=False,
+        )
+        output_tensor = pl_module.output_mask.apply(output_tensor, dim=2, fill_value=np.nan).numpy()
+        data[1:, ...] = pl_module.output_mask.apply(data[1:, ...], dim=2, fill_value=np.nan)
+        data = data.numpy()
+
+        for rollout_step in range(pl_module.rollout):
+            fig = plot_predicted_multilevel_flat_sample(
+                plot_parameters_dict,
+                self.per_sample,
+                self.latlons,
+                self.accumulation_levels_plot,
+                self.cmap_accumulation,
+                data[0, ...].squeeze(),
+                data[rollout_step + 1, ...].squeeze(),
+                output_tensor[rollout_step, ...],
+                precip_and_related_fields=self.precip_and_related_fields,
+            )
+
+            self._output_figure(
+                logger,
+                fig,
+                epoch=epoch,
+                tag=f"gnn_pred_val_sample_rstep{rollout_step:02d}_batch{batch_idx:04d}_rank0",
+                exp_log_tag=f"val_pred_sample_rstep{rollout_step:02d}_rank{local_rank:01d}",
+            )
+
+
+class BasePlotAdditionalMetrics(BasePerBatchPlotCallback):
+    """Base processing class for additional metrics."""
+
+    def process(
+        self,
+        pl_module: pl.LightningModule,
+        outputs: list,
+        batch: torch.Tensor,
+    ) -> tuple[np.ndarray, np.ndarray]:
+        # When running in async mode, these tensors may have been moved to the CPU in the
+        # last epoch (the denormalising would then fail, as 'input_tensor' would be on CUDA
+        # while the internal tensors would be on the CPU). The lines below address this.
+        if self.pre_processors is None:
+            # Copy to be used across all the training cycle
+            self.pre_processors = copy.deepcopy(pl_module.model.pre_processors).cpu()
+        if self.post_processors is None:
+            # Copy to be used across all the training cycle
+            self.post_processors = copy.deepcopy(pl_module.model.post_processors).cpu()
+        if self.latlons is None:
+            self.latlons = np.rad2deg(pl_module.latlons_data.clone().cpu().numpy())
+
+        batch = pl_module.model.pre_processors(batch, in_place=False)
+        input_tensor = batch[
+            self.sample_idx,
+            pl_module.multi_step - 1 : pl_module.multi_step + pl_module.rollout + 1,
+            ...,
+            pl_module.data_indices.internal_data.output.full,
+        ].cpu()
+
+        data = self.post_processors(input_tensor)
+        output_tensor = self.post_processors(
+            torch.cat(tuple(x[self.sample_idx : self.sample_idx + 1, ...].cpu() for x in outputs[1])),
+            in_place=False,
+        )
+        output_tensor = pl_module.output_mask.apply(output_tensor, dim=2, fill_value=np.nan).numpy()
+        data[1:, ...] = pl_module.output_mask.apply(data[1:, ...], dim=2, fill_value=np.nan)
+        data = data.numpy()
+        return data, output_tensor
+
+
+class PlotSpectrum(BasePlotAdditionalMetrics):
+    """Plots the power spectrum comparing target and prediction.
+
+    The actual increment (output - input) is plotted for prognostic variables, while the
+    output itself is plotted for diagnostic ones.
+    """
+
+    def __init__(
+        self,
+        config: OmegaConf,
+        sample_idx: int,
+        parameters: list[str],
+        every_n_batches: int | None = None,
+    ) -> None:
+        """Initialise the PlotSpectrum callback.
+
+        Parameters
+        ----------
+        config : OmegaConf
+            Config object
+        sample_idx : int
+            Sample to plot
+        parameters : list[str]
+            Parameters to plot
+        every_n_batches : int | None, optional
+            Override for batch frequency, by default None
+        """
+        super().__init__(config, every_n_batches=every_n_batches)
+        self.sample_idx = sample_idx
+        self.parameters = parameters
+
+    @rank_zero_only
+    def _plot(
+        self,
+        trainer: pl.Trainer,
+        pl_module: pl.LightningModule,
+        outputs: list,
+        batch: torch.Tensor,
+        batch_idx: int,
+        epoch: int,
+    ) -> None:
+        logger = trainer.logger
+
+        local_rank = pl_module.local_rank
+        data, output_tensor = self.process(pl_module, outputs, batch)
+
+        for rollout_step in range(pl_module.rollout):
+            # Build dictionary of indices and parameters to be plotted
+            diagnostics = [] if self.config.data.diagnostic is None else self.config.data.diagnostic
+            plot_parameters_dict_spectrum = {
+                pl_module.data_indices.model.output.name_to_index[name]: (
+                    name,
+                    name not in diagnostics,
+                )
+                for name in self.parameters
+            }
+
+            fig = plot_power_spectrum(
+                plot_parameters_dict_spectrum,
+                self.latlons,
+                data[0, ...].squeeze(),
+                data[rollout_step + 1, ...].squeeze(),
+                output_tensor[rollout_step, ...],
+            )
+
+            self._output_figure(
+                logger,
+                fig,
+                epoch=epoch,
+                tag=f"gnn_pred_val_spec_rstep_{rollout_step:02d}_batch{batch_idx:04d}_rank0",
+                exp_log_tag=f"val_pred_spec_rstep_{rollout_step:02d}_rank{local_rank:01d}",
+            )
+
+
+class PlotHistogram(BasePlotAdditionalMetrics):
+    """Plots histograms comparing target and prediction.
+
+    The actual increment (output - input) is plotted for prognostic variables, while the
+    output itself is plotted for diagnostic ones.
+    """
+
+    def __init__(
+        self,
+        config: OmegaConf,
+        sample_idx: int,
+        parameters: list[str],
+        precip_and_related_fields: list[str] | None = None,
+        every_n_batches: int | None = None,
+    ) -> None:
+        """Initialise the PlotHistogram callback.
+
+        Parameters
+        ----------
+        config : OmegaConf
+            Config object
+        sample_idx : int
+            Sample to plot
+        parameters : list[str]
+            Parameters to plot
+        precip_and_related_fields : list[str] | None, optional
+            Precip variable names, by default None
+        every_n_batches : int | None, optional
+            Override for batch frequency, by default None
+        """
+        super().__init__(config, every_n_batches=every_n_batches)
+        self.sample_idx = sample_idx
+        self.parameters = parameters
+        self.precip_and_related_fields = precip_and_related_fields
+        LOGGER.info(
+            "Using precip histogram plotting method for fields: %s.",
+            self.precip_and_related_fields,
+        )
+
+    @rank_zero_only
+    def _plot(
+        self,
+        trainer: pl.Trainer,
+        pl_module: pl.LightningModule,
+        outputs: list,
+        batch: torch.Tensor,
+        batch_idx: int,
+        epoch: int,
+    ) -> None:
+        logger = trainer.logger
+
+        local_rank = pl_module.local_rank
+        data, output_tensor = self.process(pl_module, outputs, batch)
+
+        for rollout_step in range(pl_module.rollout):
+            # Build dictionary of indices and parameters to be plotted
+            diagnostics = [] if self.config.data.diagnostic is None else self.config.data.diagnostic
+
+            plot_parameters_dict_histogram = {
+                pl_module.data_indices.model.output.name_to_index[name]: (
+                    name,
+                    name not in diagnostics,
+                )
+                for name in self.parameters
+            }
+
+            fig = plot_histogram(
+                plot_parameters_dict_histogram,
+                data[0, ...].squeeze(),
+                data[rollout_step + 1, ...].squeeze(),
+                output_tensor[rollout_step, ...],
+                self.precip_and_related_fields,
+            )
+
+            self._output_figure(
+                logger,
+                fig,
+                epoch=epoch,
+                tag=f"gnn_pred_val_histo_rstep_{rollout_step:02d}_batch{batch_idx:04d}_rank0",
+                exp_log_tag=f"val_pred_histo_rstep_{rollout_step:02d}_rank{local_rank:01d}",
+            )
diff --git a/src/anemoi/training/diagnostics/callbacks/profiler.py b/src/anemoi/training/diagnostics/callbacks/profiler.py
new file mode 100644
index 00000000..13a698fa
--- /dev/null
+++ b/src/anemoi/training/diagnostics/callbacks/profiler.py
@@ -0,0 +1,98 @@
+# (C) Copyright 2024 Anemoi contributors.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
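+
+# A minimal sketch (illustrative, not part of the code) of the config keys this
+# callback reads in MemorySnapshotRecorder.__init__ below; the surrounding
+# layout is assumed from those accesses:
+#
+#   diagnostics:
+#     benchmark_profiler:
+#       snapshot:
+#         steps: 4    # batches to record after the warmup
+#         warmup: 2   # batches to skip before recording starts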
+
+# ruff: noqa: ANN001
+
+from __future__ import annotations
+
+import logging
+from pathlib import Path
+from typing import TYPE_CHECKING
+from typing import Any
+
+import torch
+from pytorch_lightning.callbacks import Callback
+from pytorch_lightning.utilities import rank_zero_only
+
+if TYPE_CHECKING:
+    import pytorch_lightning as pl
+    from pytorch_lightning.utilities.types import STEP_OUTPUT
+
+LOGGER = logging.getLogger(__name__)
+
+
+class MemorySnapshotRecorder(Callback):
+    """Record memory snapshot using torch.cuda._record_memory_history()."""
+
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.dirpath = Path(self.config.hardware.paths.profiler)
+
+        self.warmup = self.config.diagnostics.benchmark_profiler.snapshot.warmup
+        if not self.warmup:
+            self.warmup = 0
+        self.num_steps = (
+            self.config.diagnostics.benchmark_profiler.snapshot.steps + self.warmup
+        )  # be consistent with profiler scheduler
+        self.status = False
+
+        assert (
+            self.num_steps % self.config.dataloader.batch_size.training == 0
+        ), "Snapshot steps is not a multiple of batch size"
+        assert (
+            self.warmup % self.config.dataloader.batch_size.training == 0
+        ), "Snapshot Warmup steps is not a multiple of batch size"
+
+    @rank_zero_only
+    def _start_snapshot_recording(self) -> None:
+        LOGGER.info("Starting snapshot record_memory_history")
+        torch.cuda.memory._record_memory_history()
+        self.status = True
+
+    @rank_zero_only
+    def _save_snapshot(self) -> None:
+        self.memory_snapshot_fname = self.dirpath / "memory_snapshot.pickle"
+        try:
+            LOGGER.info("Saving memory snapshot to %s", self.memory_snapshot_fname)
+            torch.cuda.memory._dump_snapshot(f"{self.memory_snapshot_fname}")
+        except BaseException:
+            LOGGER.exception("Failed to capture memory snapshot")
+
+    @rank_zero_only
+    def stop_record_memory_history(self) -> None:
+        LOGGER.info("Stopping snapshot record_memory_history")
+        torch.cuda.memory._record_memory_history(enabled=None)
+
+    def on_train_batch_start(
+        self,
+        trainer: pl.Trainer,
+        pl_module: pl.LightningModule,
+        batch: Any,
+        batch_idx: int,
+    ) -> None:
+        del pl_module, batch, batch_idx
+        if trainer.global_step == self.warmup:
+            self._start_snapshot_recording()
+
+    def on_train_batch_end(
+        self,
+        trainer: pl.Trainer,
+        pl_module: pl.LightningModule,
+        outputs: STEP_OUTPUT,
+        batch: Any,
+        batch_idx: int,
+    ) -> None:
+        del batch, batch_idx, pl_module, outputs
+        if trainer.global_step == self.num_steps:
+            if self.status is True:
+                self._save_snapshot()
+                self.stop_record_memory_history()
+            else:
+                LOGGER.info("Snapshot recording was not started so no snapshot was saved")
diff --git a/src/anemoi/training/diagnostics/callbacks/provenance.py b/src/anemoi/training/diagnostics/callbacks/provenance.py
new file mode 100644
index 00000000..414f0311
--- /dev/null
+++ b/src/anemoi/training/diagnostics/callbacks/provenance.py
@@ -0,0 +1,47 @@
+# (C) Copyright 2024 Anemoi contributors.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
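+
+# Illustrative data flow (a sketch, not executed code): when a run is warm-started
+# from a checkpoint, ParentUUIDCallback below copies
+#   checkpoint["hyper_parameters"]["metadata"]["uuid"]
+# into
+#   pl_module.hparams["metadata"]["parent_uuid"]
+# so the lineage of a child model can be traced.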
+
+from __future__ import annotations
+
+import logging
+from typing import TYPE_CHECKING
+
+from pytorch_lightning.callbacks import Callback
+
+if TYPE_CHECKING:
+    import pytorch_lightning as pl
+    from omegaconf import OmegaConf
+
+LOGGER = logging.getLogger(__name__)
+
+
+class ParentUUIDCallback(Callback):
+    """A callback that retrieves the parent UUID for a model, if it is a child model."""
+
+    def __init__(self, config: OmegaConf) -> None:
+        """Initialise the ParentUUIDCallback callback.
+
+        Parameters
+        ----------
+        config : OmegaConf
+            Config object
+
+        """
+        super().__init__()
+        self.config = config
+
+    def on_load_checkpoint(
+        self,
+        trainer: pl.Trainer,
+        pl_module: pl.LightningModule,
+        checkpoint: dict,
+    ) -> None:
+        del trainer  # unused
+        pl_module.hparams["metadata"]["parent_uuid"] = checkpoint["hyper_parameters"]["metadata"]["uuid"]
diff --git a/src/anemoi/training/diagnostics/logger.py b/src/anemoi/training/diagnostics/logger.py
index 4e4a35c1..698c7c50 100644
--- a/src/anemoi/training/diagnostics/logger.py
+++ b/src/anemoi/training/diagnostics/logger.py
@@ -73,7 +73,10 @@ def get_mlflow_logger(config: DictConfig) -> None:
     )
 
     config_params = OmegaConf.to_container(config, resolve=True)
-    logger.log_hyperparams(config_params)
+    logger.log_hyperparams(
+        config_params,
+        expand_keys=config.diagnostics.log.mlflow.get("expand_hyperparams", ["config"]),
+    )
 
     if config.diagnostics.log.mlflow.terminal:
         logger.log_terminal_output(artifact_save_dir=config.hardware.paths.plots)
diff --git a/src/anemoi/training/diagnostics/mlflow/__init__.py b/src/anemoi/training/diagnostics/mlflow/__init__.py
index 282d6a69..c167afa2 100644
--- a/src/anemoi/training/diagnostics/mlflow/__init__.py
+++ b/src/anemoi/training/diagnostics/mlflow/__init__.py
@@ -1,6 +1,8 @@
-# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts.
+# (C) Copyright 2024 Anemoi contributors.
+#
 # This software is licensed under the terms of the Apache Licence Version 2.0
 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
 # In applying this licence, ECMWF does not waive the privileges and immunities
 # granted to it by virtue of its status as an intergovernmental organisation
 # nor does it submit to any jurisdiction.
diff --git a/src/anemoi/training/diagnostics/mlflow/logger.py b/src/anemoi/training/diagnostics/mlflow/logger.py
index 183e7a0d..7c482ce1 100644
--- a/src/anemoi/training/diagnostics/mlflow/logger.py
+++ b/src/anemoi/training/diagnostics/mlflow/logger.py
@@ -30,6 +30,7 @@ from pytorch_lightning.utilities.rank_zero import rank_zero_only
 
 from anemoi.training.diagnostics.mlflow.auth import TokenAuth
+from anemoi.training.diagnostics.mlflow.utils import expand_iterables
 from anemoi.training.diagnostics.mlflow.utils import health_check
 from anemoi.training.utils.jsonify import map_config_to_primitives
 
@@ -432,10 +433,59 @@ def experiment(self) -> MLFlowLogger.experiment:
     def log_system_metrics(self) -> None:
         """Log system metrics (CPU, GPU, etc)."""
         import mlflow
+        import psutil
+        from mlflow.system_metrics.metrics.base_metrics_monitor import BaseMetricsMonitor
+        from mlflow.system_metrics.metrics.disk_monitor import DiskMonitor
+        from mlflow.system_metrics.metrics.gpu_monitor import GPUMonitor
+        from mlflow.system_metrics.metrics.network_monitor import NetworkMonitor
         from mlflow.system_metrics.system_metrics_monitor import SystemMetricsMonitor
 
+        class CustomCPUMonitor(BaseMetricsMonitor):
+            """Class for monitoring CPU stats.
+
+            Extends the default CPUMonitor to also measure total memory,
+            and uses a different formula for calculating used memory.
+            """
+
+            def collect_metrics(self) -> None:
+                # Get CPU metrics.
+                cpu_percent = psutil.cpu_percent()
+                self._metrics["cpu_utilization_percentage"].append(cpu_percent)
+
+                system_memory = psutil.virtual_memory()
+                # Change the formula for measuring CPU memory usage.
+                # By default MLflow uses psutil.virtual_memory().used.
+                # Tests have shown that "used" underreports memory usage by as much as a factor of 2;
+                # "used" also misses increased memory usage from using a higher prefetch factor.
+                self._metrics["system_memory_usage_megabytes"].append(
+                    (system_memory.total - system_memory.available) / 1e6,
+                )
+                self._metrics["system_memory_usage_percentage"].append(system_memory.percent)
+
+                # QOL: report the total system memory in raw numbers
+                self._metrics["system_memory_total_megabytes"].append(system_memory.total / 1e6)
+
+            def aggregate_metrics(self) -> dict[str, float]:
+                return {k: round(sum(v) / len(v), 1) for k, v in self._metrics.items()}
+
+        class CustomSystemMetricsMonitor(SystemMetricsMonitor):
+            def __init__(self, run_id: str, resume_logging: bool = False):
+                super().__init__(run_id, resume_logging=resume_logging)
+
+                # Replace the CPUMonitor with the custom implementation
+                self.monitors = [CustomCPUMonitor(), DiskMonitor(), NetworkMonitor()]
+                try:
+                    gpu_monitor = GPUMonitor()
+                    self.monitors.append(gpu_monitor)
+                except ImportError:
+                    LOGGER.warning(
+                        "`pynvml` is not installed; to log GPU metrics please run `pip install pynvml`.",
+                    )
+
+        mlflow.enable_system_metrics_logging()
-        system_monitor = SystemMetricsMonitor(
+        system_monitor = CustomSystemMetricsMonitor(
             self.run_id,
             resume_logging=self.run_id is not None,
         )
@@ -483,8 +533,39 @@ def _clean_params(params: dict[str, Any]) -> dict[str, Any]:
         return params
 
     @rank_zero_only
-    def log_hyperparams(self, params: dict[str, Any] | Namespace) -> None:
-        """Overwrite the log_hyperparams method to flatten config params using '.'."""
+    def log_hyperparams_as_artifact(self, params: dict[str, Any] | Namespace) -> None:
+        """Log hyperparameters as an artifact."""
+        import json
+        import tempfile
+        from json import JSONEncoder
+
+        class StrEncoder(JSONEncoder):
+            def default(self, o: Any) -> str:
+                return str(o)
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            path = Path(tmp_dir) / "config.json"
+            with Path.open(path, "w") as f:
+                json.dump(params, f, cls=StrEncoder)
+            self.experiment.log_artifact(run_id=self.run_id, local_path=path)
+
+    @rank_zero_only
+    def log_hyperparams(self, params: dict[str, Any] | Namespace, *, expand_keys: list[str] | None = None) -> None:
+        """Overwrite the log_hyperparams method.
+
+        - flatten config params using '.'.
+        - expand keys within params to avoid truncation.
+        - log hyperparameters as an artifact.
+
+        Parameters
+        ----------
+        params : dict[str, Any] | Namespace
+            params to log
+        expand_keys : list[str] | None, optional
+            keys to expand within params. Any key being expanded will
+            have lists converted according to `expand_iterables`,
+            by default None.
+        """
         if self._flag_log_hparams:
             params = _convert_params(params)
 
@@ -492,17 +573,34 @@ def log_hyperparams(self, params: dict[str, Any] | Namespace) -> None:
             if config := params.get("config"):
                 params["config"] = map_config_to_primitives(config)
 
-            params = _flatten_dict(params, delimiter=".")  # Flatten dict with '.' to not break API queries
-            params = self._clean_params(params)
-
             import mlflow
             from mlflow.entities import Param
 
-            # Truncate parameter values.
             truncation_length = 250
+
             if Version(mlflow.VERSION) >= Version("1.28.0"):
                 truncation_length = 500
-            params_list = [Param(key=k, value=str(v)[:truncation_length]) for k, v in params.items()]
+
+            self.log_hyperparams_as_artifact(params)
+
+            expanded_params = {}
+            params = params.copy()
+
+            for key in expand_keys or []:
+                if key in params:
+                    expanded_params.update(
+                        expand_iterables(params.pop(key), size_threshold=None, delimiter="."),
+                    )
+            expanded_params.update(params)
+
+            expanded_params = _flatten_dict(
+                expanded_params,
+                delimiter=".",
+            )  # Flatten dict with '.' to not break API queries
+            expanded_params = self._clean_params(expanded_params)
+
+            # Truncate parameter values.
+            params_list = [Param(key=k, value=str(v)[:truncation_length]) for k, v in expanded_params.items()]
 
             for idx in range(0, len(params_list), 100):
                 self.experiment.log_batch(run_id=self.run_id, params=params_list[idx : idx + 100])
diff --git a/src/anemoi/training/diagnostics/mlflow/utils.py b/src/anemoi/training/diagnostics/mlflow/utils.py
index 89f6e002..929183bc 100644
--- a/src/anemoi/training/diagnostics/mlflow/utils.py
+++ b/src/anemoi/training/diagnostics/mlflow/utils.py
@@ -6,9 +6,11 @@
 # In applying this licence, ECMWF does not waive the privileges and immunities
 # granted to it by virtue of its status as an intergovernmental organisation
 # nor does it submit to any jurisdiction.
+from __future__ import annotations
 
-
+import functools
 import os
+from typing import Any
 
 import requests
 
@@ -36,3 +38,79 @@ def health_check(tracking_uri: str) -> None:
         if not token:
             error_msg += "The server may require authentication, did you forget to turn it on?"
         raise ConnectionError(error_msg)
+
+
+def expand_iterables(
+    params: dict[str, Any],
+    *,
+    size_threshold: int | None = None,
+    recursive: bool = True,
+    delimiter: str = ".",
+) -> dict[str, Any]:
+    """Expand any iterable values to the form {key.i: value_i}.
+
+    If expanded, will also add {key.all: [value_0, value_1, ...], key.length: len([value_0, value_1, ...])}.
+
+    If `size_threshold` is not None, expand the iterable only if the length of str(value) is
+    greater than `size_threshold`.
+
+    Parameters
+    ----------
+    params : dict[str, Any]
+        Parameters to be expanded.
+    size_threshold : int | None, optional
+        Threshold of str(value) to expand iterable at.
+        Default is None.
+    recursive : bool, optional
+        Expand nested dictionaries.
+        Default is True.
+    delimiter : str, optional
+        Delimiter to use for keys.
+        Default is ".".
+
+    Returns
+    -------
+    dict[str, Any]
+        Dictionary with all iterable values expanded.
+
+    Examples
+    --------
+    >>> expand_iterables({'a': ['a', 'b', 'c']})
+    {'a.0': 'a', 'a.1': 'b', 'a.2': 'c', 'a.all': ['a', 'b', 'c'], 'a.length': 3}
+    >>> expand_iterables({'a': {'b': ['a', 'b', 'c']}})
+    {'a': {'b.0': 'a', 'b.1': 'b', 'b.2': 'c', 'b.all': ['a', 'b', 'c'], 'b.length': 3}}
+    >>> expand_iterables({'a': ['a', 'b', 'c']}, size_threshold=100)
+    {'a': ['a', 'b', 'c']}
+    >>> expand_iterables({'a': [[0,1,2], 'b', 'c']})
+    {'a.0': {0: 0, 1: 1, 2: 2}, 'a.1': 'b', 'a.2': 'c', 'a.all': [[0, 1, 2], 'b', 'c'], 'a.length': 3}
+    """
+
+    def should_be_expanded(x: Any) -> bool:
+        return size_threshold is None or len(str(x)) > size_threshold
+
+    nested_func = functools.partial(expand_iterables, size_threshold=size_threshold, recursive=recursive)
+
+    def expand(val: dict | list) -> dict[str, Any]:
+        if not recursive:
+            return val
+        if isinstance(val, dict):
+            return nested_func(val)
+        if isinstance(val, list):
+            return nested_func(dict(enumerate(val)))
+        return val
+
+    expanded_params = {}
+
+    for key, value in params.items():
+        if isinstance(value, (list, tuple)):
+            if should_be_expanded(value):
+                for i, v in enumerate(value):
+                    expanded_params[f"{key}{delimiter}{i}"] = expand(v)
+
+                expanded_params[f"{key}{delimiter}all"] = value
+                expanded_params[f"{key}{delimiter}length"] = len(value)
+            else:
+                expanded_params[key] = value
+        else:
+            expanded_params[key] = expand(value)
+    return expanded_params
diff --git a/src/anemoi/training/diagnostics/plots.py b/src/anemoi/training/diagnostics/plots.py
index 7b4ba711..dde80018 100644
--- a/src/anemoi/training/diagnostics/plots.py
+++ b/src/anemoi/training/diagnostics/plots.py
@@ -16,6 +16,8 @@ import matplotlib.pyplot as plt
 import matplotlib.style as mplstyle
 import numpy as np
+from anemoi.models.layers.mapper import GraphEdgeMixin
+from matplotlib.collections import LineCollection
 from matplotlib.colors import BoundaryNorm
 from matplotlib.colors import ListedColormap
 from matplotlib.colors import TwoSlopeNorm
@@ -28,6 +30,7 @@
 
 if TYPE_CHECKING:
     from matplotlib.figure import Figure
+    from torch import nn
 
 from dataclasses import dataclass
 
@@ -587,7 +590,7 @@ def scatter_plot(
 
     Parameters
     ----------
-    fig : _type_
+    fig : Figure
         Figure object handle
     ax : matplotlib.axes
         Axis object handle
@@ -628,36 +631,144 @@ def scatter_plot(
     fig.colorbar(psc, ax=ax)
 
 
-def plot_graph_features(
-    latlons: np.ndarray,
-    features: np.ndarray,
-) -> Figure:
-    """Plot trainable graph features.
+def edge_plot(
+    fig: Figure,
+    ax: plt.Axes,
+    src_coords: np.ndarray,
+    dst_coords: np.ndarray,
+    data: np.ndarray,
+    cmap: str = "coolwarm",
+    title: str | None = None,
+) -> None:
+    """Lat-lon line plot.
 
     Parameters
     ----------
-    latlons : np.ndarray
-        Latitudes and longitudes
-    features : np.ndarray
-        Trainable Features
+    fig : Figure
+        Figure object handle
+    ax : plt.Axes
+        Axis object handle
+    src_coords : np.ndarray of shape (num_edges, 2)
+        Source latitudes and longitudes.
+    dst_coords : np.ndarray of shape (num_edges, 2)
+        Destination latitudes and longitudes.
+    data : np.ndarray of shape (num_edges, 1)
+        Data to plot
+    cmap : str, optional
+        Colormap string from matplotlib, by default "coolwarm".
+    title : str, optional
+        Title for plot, by default None
+    """
+    edge_lines = np.stack([src_coords, dst_coords], axis=1)
+    lc = LineCollection(edge_lines, cmap=cmap, linewidths=1)
+    lc.set_array(data)
+
+    psc = ax.add_collection(lc)
+
+    xmin, xmax = edge_lines[:, 0, 0].min(), edge_lines[:, 0, 0].max()
+    ymin, ymax = edge_lines[:, 1, 1].min(), edge_lines[:, 1, 1].max()
+    ax.set_xlim((xmin - 0.1, xmax + 0.1))
+    ax.set_ylim((ymin - 0.1, ymax + 0.1))
+
+    continents.plot_continents(ax)
+
+    if title is not None:
+        ax.set_title(title)
+
+    ax.set_aspect("auto", adjustable=None)
+    _hide_axes_ticks(ax)
+    fig.colorbar(psc, ax=ax)
+
+
+def plot_graph_node_features(model: nn.Module) -> Figure:
+    """Plot trainable graph node features.
+
+    Parameters
+    ----------
+    model: AnemoiModelEncProcDec
+        Model object
 
     Returns
     -------
     Figure
         Figure object handle
+    """
+    nrows = len(nodes_name := model._graph_data.node_types)
+    ncols = min(model.node_attributes.trainable_tensors[m].trainable.shape[1] for m in nodes_name)
+    figsize = (ncols * 4, nrows * 3)
+    fig, ax = plt.subplots(nrows, ncols, figsize=figsize)
+
+    for row, (mesh, trainable_tensor) in enumerate(model.node_attributes.trainable_tensors.items()):
+        latlons = model.node_attributes.get_coordinates(mesh).cpu().numpy()
+        node_features = trainable_tensor.trainable.cpu().detach().numpy()
+
+        lat, lon = latlons[:, 0], latlons[:, 1]
+
+        for i in range(ncols):
+            ax_ = ax[row, i] if ncols > 1 else ax[row]
+            scatter_plot(
+                fig,
+                ax_,
+                lon=lon,
+                lat=lat,
+                data=node_features[..., i],
+                title=f"{mesh} trainable feature #{i + 1}",
+            )
+
+    return fig
+
+
+def plot_graph_edge_features(model: nn.Module, q_extreme_limit: float = 0.05) -> Figure:
+    """Plot trainable graph edge features.
+
+    Parameters
+    ----------
+    model: AnemoiModelEncProcDec
+        Model object
+    q_extreme_limit : float, optional
+        Plot top & bottom quantile of edges trainable values, by default 0.05 (5%).
+
+    Returns
+    -------
+    Figure
+        Figure object handle
-    """
-    nplots = features.shape[-1]
-    figsize = (nplots * 4, 3)
-    fig, ax = plt.subplots(1, nplots, figsize=figsize)
+    """
+    trainable_modules = {
+        (model._graph_name_data, model._graph_name_hidden): model.encoder,
+        (model._graph_name_hidden, model._graph_name_data): model.decoder,
+    }
 
-    lat, lon = latlons[:, 0], latlons[:, 1]
+    if isinstance(model.processor, GraphEdgeMixin):
+        trainable_modules[model._graph_name_hidden, model._graph_name_hidden] = model.processor
 
-    pc = EquirectangularProjection()
-    pc_lon, pc_lat = pc(lon, lat)
+    ncols = min(module.trainable.trainable.shape[1] for module in trainable_modules.values())
+    nrows = len(trainable_modules)
+    figsize = (ncols * 4, nrows * 3)
+    fig, ax = plt.subplots(nrows, ncols, figsize=figsize)
+
+    for row, ((src, dst), graph_mapper) in enumerate(trainable_modules.items()):
+        src_coords = model.node_attributes.get_coordinates(src).cpu().numpy()
+        dst_coords = model.node_attributes.get_coordinates(dst).cpu().numpy()
+        edge_index = graph_mapper.edge_index_base.cpu().numpy()
+        edge_features = graph_mapper.trainable.trainable.cpu().detach().numpy()
+
+        for i in range(ncols):
+            ax_ = ax[row, i] if ncols > 1 else ax[row]
+            feature = edge_features[..., i]
+
+            # Get mask of feature values over top and bottom percentiles
+            top_perc = np.quantile(feature, 1 - q_extreme_limit)
+            bottom_perc = np.quantile(feature, q_extreme_limit)
 
-    for i in range(nplots):
-        ax_ = ax[i] if nplots > 1 else ax
-        scatter_plot(fig, ax_, lon=pc_lon, lat=pc_lat, data=features[..., i])
+            mask = (feature >= top_perc) | (feature <= bottom_perc)
+
+            edge_plot(
+                fig,
+                ax_,
+                src_coords[edge_index[0, mask]][:, ::-1],
+                dst_coords[edge_index[1, mask]][:, ::-1],
+                feature[mask],
+                title=f"{src} -> {dst} trainable feature #{i + 1}",
+            )
 
     return fig
diff --git a/src/anemoi/training/diagnostics/profilers.py b/src/anemoi/training/diagnostics/profilers.py
new file mode 100644
index 00000000..96a6f3c4
--- /dev/null
+++ b/src/anemoi/training/diagnostics/profilers.py
@@ -0,0 +1,713 @@
+# (C) Copyright 2024 Anemoi contributors.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
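+
+# Illustrative behaviour (a sketch based on the implementation below) of the
+# convert_to_seconds() helper used when parsing the profiler's summary rows:
+#   convert_to_seconds("10 ms")  -> 0.01
+#   convert_to_seconds("2.5 s")  -> 2.5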
+
+from __future__ import annotations
+
+import logging
+import re
+from pathlib import Path
+from typing import TYPE_CHECKING
+from typing import Any
+
+import numpy as np
+import pandas as pd
+import torch
+import torch.distributed as dist
+from pytorch_lightning.callbacks import TQDMProgressBar
+from pytorch_lightning.profilers import Profiler
+from pytorch_lightning.profilers import PyTorchProfiler
+from pytorch_lightning.profilers import SimpleProfiler
+from pytorch_lightning.utilities import rank_zero_only
+
+if TYPE_CHECKING:
+    import importlib
+
+    import pytorch_lightning as pl
+    from omegaconf import DictConfig
+    from pytorch_lightning.utilities.types import STEP_OUTPUT
+
+    from anemoi.training.train.forecaster import GraphForecaster
+
+    if importlib.util.find_spec("ipywidgets") is not None:
+        from tqdm.auto import tqdm as _tqdm
+    else:
+        from tqdm import tqdm as _tqdm
+
+from torch.profiler import profile
+
+from anemoi.training.diagnostics.mlflow.logger import AnemoiMLflowLogger
+
+LOGGER = logging.getLogger(__name__)
+
+
+def check_torch_version() -> bool:
+    torch_version = torch.__version__
+    version_nums = torch_version.split(".")
+    major_version = int(version_nums[0])
+    minor_version = int(version_nums[1])
+    if major_version == 2 and minor_version >= 1:
+        return True
+    LOGGER.error("Memory snapshot is only supported for torch >= 2.1")
+    return False
+
+
+def convert_to_seconds(time_str: str) -> float:
+    pattern = r"(\d+(\.\d+)?)\s*([a-zA-Z]+)"
+    # Use regex to find matches
+    match = re.match(pattern, time_str)
+
+    # Check if a match is found
+    if match:
+        # Extract the numeric part and the unit part
+        numeric_part = float(match.group(1))
+        unit = match.group(3)
+
+        # Convert the unit to seconds
+        if unit == "s":
+            return numeric_part
+        if unit == "ds":
+            return numeric_part / 10  # Convert deciseconds to seconds
+        if unit == "cs":
+            return numeric_part / 100  # Convert centiseconds to seconds
+        if unit == "ms":
+            return numeric_part / 1000  # Convert milliseconds to seconds
+        error_msg = (
+            "Invalid unit. Supported units are: 's' (seconds), "
+            "'ds' (deciseconds), 'cs' (centiseconds) and 'ms' (milliseconds)."
+        )
+        raise ValueError(error_msg)
+    error_msg = "Invalid time format. The time should be in the format: 'numeric_part unit'. For example: '10 ms'"
+    raise ValueError(error_msg)
+
+
+PROFILER_ACTIONS = [
+    r"\[Strategy]\w+\.batch_to_device",
+    r"\[Strategy]\w+\.backward",
+    r"\[Strategy]\w+\.training_step",
+    r"\[Strategy]\w+\.validation_step",
+    "run_training_epoch",
+    "run_training_batch",
+    r"\[_EvaluationLoop\]\.\w+",
+    r"\[_TrainingEpochLoop\]\.\w+",
+    r"\[LightningDataModule]\w+\.train_dataloader",
+    r"\[LightningDataModule]\w+\.val_dataloader",
+    r"\[LightningDataModule]\w+\.state_dict",
+    r"\[LightningDataModule]\w+\.setup",
+    r"\[LightningDataModule]\w+\.prepare_data",
+    r"\[LightningDataModule]\w+\.teardown",
+    r"\[LightningModule]\w+\.optimizer_step",
+    r"\[LightningModule]\w+\.configure_gradient_clipping",
+    r"\[LightningModule]\w+\.on_validation_model_eval",
+    r"\[LightningModule]\w+\.optimizer_zero_grad",
+    r"\[LightningModule]\w+\.transfer_batch_to_device",
+    r"\[LightningModule]\w+\.on_validation_model_train",
+    r"\[LightningModule]\w+\.configure_optimizers",
+    r"\[LightningModule]\w+\.lr_scheduler_step",
+    r"\[LightningModule]\w+\.configure_sharded_model",
+    r"\[LightningModule]\w+\.setup",
+    r"\[LightningModule]\w+\.prepare_data",
+    r"\[Callback\](.*Plot*)",
+    r"\[Callback\](.*Checkpoint*)",
+]
+
+GPU_METRICS_DICT = {
+    "GPU device utilization (%)": "gpu",
+    "GPU memory use (%)": "memory",
+    "GPU memory allocated (%)": "memoryAllocated",
+    "GPU memory allocated (GB)": "memoryAllocatedBytes",
+}
+
+
+class WandBSystemSummarizer:
+    """Summarize System Metrics provided by W&B logger."""
+
+    def __init__(self, wandb_logger: pl.loggers.WandbLogger):
+        run_dict = wandb_logger._wandb_init
+        self.run_id_path = f"{run_dict['entity']}/{run_dict['project']}/{run_dict['id']}"
+
+    def get_wandb_metrics(self) -> tuple[pd.DataFrame, dict]:
+        """Fetches system metrics and metadata from a W&B run."""
+        import wandb
+
+        run = wandb.Api().run(self.run_id_path)
+        system_metrics = run.history(stream="events")
+        metadata_dict = run.metadata
+        system_metrics = system_metrics.dropna()
+        return system_metrics, metadata_dict
+
+    def summarize_gpu_metrics(self, df: pd.DataFrame) -> dict[str, float]:
+        """Given the system metrics DataFrame, summarize the GPU metrics.
+
+        - gpu.{gpu_index}.memory - GPU memory utilization in percent for each GPU
+        - gpu.{gpu_index}.memoryAllocated - GPU memory allocated as % of the total available memory for each GPU
+        - gpu.{gpu_index}.memoryAllocatedBytes - GPU memory allocated in bytes for each GPU
+        - gpu.{gpu_index}.gpu - GPU utilization in percent for each GPU
+        """
+        average_metric = {}
+        col_names = df.columns
+        for gpu_metric_name, gpu_metric in GPU_METRICS_DICT.items():
+            pattern = rf"system.gpu.\d.{gpu_metric}$"
+            sub_gpu_cols = [string for string in col_names if re.match(pattern, string)]
+            metrics_per_gpu = df[sub_gpu_cols].mean(axis=0)
+            if gpu_metric == "memoryAllocatedBytes":
+                metrics_per_gpu = metrics_per_gpu * 1e-9
+            average_metric[gpu_metric_name] = metrics_per_gpu.mean()
+            # Just add metrics per gpu to the report if we have more than 1 GPU
+            if metrics_per_gpu.shape[0] > 1:
+                metrics_per_gpu.index = [" " + index for index in metrics_per_gpu.index]
+                average_metric.update(dict(metrics_per_gpu))
+        return average_metric
+
+    def summarize_system_metrics(self) -> dict[str, float]:
+        r"""Summarize the system metrics from a W&B run.
+
+        Some of the metrics included are:
+        - cpu.{}.cpu_percent - CPU usage of the system on a per-core basis.
+        - system.memory - Represents the total system memory usage as a percentage of the total available memory.
+        - system.cpu - Percentage of CPU usage by the process, normalized by the number of available CPUs
+        - system.disk.\\.usageGB - Represents the total system disk usage in gigabytes (GB)
+        - system.proc.memory.percent - Indicates the memory usage of the process as a % of the total available memory
+
+        More information about W&B system metrics can be found here:
+        https://docs.wandb.ai/guides/app/features/system-metrics
+        """
+        system_metrics_df, metadata_dict = self.get_wandb_metrics()
+
+        col_names = system_metrics_df.columns
+        system_metrics = {}
+
+        n_cpus = metadata_dict["cpu_count"]
+        cpu_cols = list(filter(lambda k: "cpu." in k, col_names))
+        system_metrics["avg CPU usage (%)"] = (system_metrics_df[cpu_cols].sum(axis=1) / n_cpus).mean()
+
+        system_metrics_gpu = self.summarize_gpu_metrics(system_metrics_df)
+        system_metrics.update(system_metrics_gpu)
+
+        system_metrics["avg Memory usage (%)"] = system_metrics_df["system.memory"].mean()
+        system_metrics["avg Disk usage (GB)"] = system_metrics_df["system.disk.\\.usageGB"].mean()
+        system_metrics["avg Disk usage (%)"] = system_metrics_df["system.disk.\\.usagePercent"].mean()
+
+        system_metrics["execution time (sec)"] = system_metrics_df["_runtime"].iloc[-1]  # in seconds
+        return system_metrics
+
+
+class MLFlowSystemSummarizer:
+    """Summarize System Metrics provided by MlFlow logger."""
+
+    def __init__(self, mlflow_logger: pl.loggers.MLFlowLogger):
+        self.run_id = mlflow_logger.run_id
+        self.mlflow_client = mlflow_logger._mlflow_client
+
+    @property
+    def system_metrics(self) -> list[str]:
+        run = self.mlflow_client.get_run(self.run_id)
+        return [metric for metric in run.data.metrics if "system" in metric]
+
+    def _clean_metric_name(self, metric_name: str) -> str:
+        return (
+            metric_name.replace("system.", "avg ")
+            .replace("_", " ")
+            .replace("megabytes", "MB")
+            .replace("percentage", "%")
+        )
+
+    def _get_mean(self, pattern: str, df: pd.DataFrame) -> float:
+        # Filter rows containing the pattern in the 'metric' column
+        filtered_rows = df[df["metric"].str.contains(pattern)]
+        return filtered_rows.loc[:, "value"].astype(np.float32).mean()
+
+    def _extract_gpu_metrics(self, df: pd.DataFrame) -> pd.DataFrame:
+        # Define the patterns you want to search for
+        pattern = r"gpu\s\d+\s+utilization"
+        df.loc[len(df.index)] = ["avg GPU utilization (%)", self._get_mean(pattern, df)]
+
+        pattern = r"gpu\s\d+\s+memory\s+usage\s+%"
+        df.loc[len(df.index)] = ["avg GPU memory usage %", self._get_mean(pattern, df)]
+
+        pattern = r"gpu\s\d+\s+memory\s+usage\s+MB"
+        df.loc[len(df.index)] = ["avg GPU memory usage MB", self._get_mean(pattern, df)]
+
+        return df
+
+    def summarize_mlflow_system_metrics(self) -> pd.DataFrame:
+        rows = []
+        for metric_name in self.system_metrics:
+            history = self.mlflow_client.get_metric_history(self.run_id, metric_name)
+            avg_value = sum(m.value for m in history) / len(history)
+            clean_name = self._clean_metric_name(history[0].key)
+            rows.append({"metric": clean_name, "value": f"{avg_value:.2f}"})
+        return self._extract_gpu_metrics(pd.DataFrame(rows))
+
+
+class DummyProfiler(Profiler):
+    """Placeholder profiler that performs no measurements."""
+
+    def __init__(self):
+        super().__init__()
+
+    def start(self, *args, **kwargs) -> None:
+        pass
+
+    def stop(self, *args, **kwargs) -> None:
+        pass
+
+
+def _convert_npint_to_int(obj: Any) -> dict | list | int | str | float:
+    """Recursively convert all np.int64 values in the input to Python int."""
+    if isinstance(obj, dict):
+        return {k: _convert_npint_to_int(v) for k, v in obj.items()}
+    if isinstance(obj, list):
+        return [_convert_npint_to_int(item) for item in obj]
+    if isinstance(obj, np.integer):
+        return int(obj)  # Convert np.int64 to int
+    return obj
+
+
+class PatchedProfile(profile):
+    """Patch of torch.profiler.profile that serialises the distributed info with JSON-safe types."""
+
+    def _get_distributed_info(self) -> dict[str, str]:
+        dist_info = super()._get_distributed_info()
+        return _convert_npint_to_int(dist_info)
+
+
+class BenchmarkProfiler(Profiler):
+    """Custom PyTorch Lightning profiler for benchmarking."""
+
+    def __init__(self, config: DictConfig) -> None:
+        super().__init__()
+
+        self.config = config
+        self.warmup = self.config.diagnostics.benchmark_profiler.memory.warmup
+        if not self.warmup:
+            self.warmup = 0
+        self.num_steps = self.config.diagnostics.benchmark_profiler.memory.steps
+
+        if self.config.diagnostics.benchmark_profiler.memory.extra_plots:
+            assert (
+                self.num_steps <= self.config.training.num_sanity_val_steps
+            ), "Memory snapshot steps must not exceed num_sanity_val_steps, to avoid memory issues"
+
+        self.dirpath = None
+        self.create_output_path()
+        # the profilers need to be initialised before the setup method because
+        # actions like configuring callbacks would trigger the profiler
+        self.memory_profiler = DummyProfiler()  # dummy profiler used as a placeholder
+        self.time_profiler = DummyProfiler()  # dummy profiler used as a placeholder
+
+    @rank_zero_only
+    def create_output_path(self) -> None:
+        self.dirpath = Path(self.config.hardware.paths.profiler)
+        self.dirpath.mkdir(parents=True, exist_ok=True)
+
+    def broadcast_profiler_path(self, string_var: str, src_rank: int) -> str:
+        from lightning_fabric.utilities.distributed import group as _group
+
+        string_var = [string_var]
+        dist.broadcast_object_list(string_var, src_rank, group=_group.WORLD)
+        return string_var[0]
+
+    def setup(self, stage: str, local_rank: int | None = None, log_dir: str | None = None) -> None:
+        del log_dir
+        # The strategy is already initialised and torch distributed is active:
+        # we need to broadcast the profiler path to all ranks to save the memory traces
+        self.dirpath = Path(self.broadcast_profiler_path(str(self.dirpath), 0))
+        self._stage = stage
+        self._local_rank = local_rank
+        self._create_time_profilers()
+        self._create_memory_profilers()
+
+    def _create_time_profilers(self) -> None:
+        """Create the profiler for time measurements."""
+        if self.config.diagnostics.benchmark_profiler.time.enabled:
+            self.time_profiler = SimpleProfiler(
+                dirpath=self.dirpath,
+            )
+
+    def _create_memory_profilers(self) -> None:
+        if self.config.diagnostics.benchmark_profiler.memory.enabled:
+            import os
+
+            def trace_handler(dir_name: Path, stage: str | None = None) -> callable:
+                def handler_fn(prof: pl.profilers.Profiler) -> None:
+                    import socket
+                    import time
+
+                    worker_name = f"{socket.gethostname()}_{os.getpid()}"
+                    file_name = str(dir_name / f"{worker_name}.{stage}.{time.time_ns()}.pt.trace.json")
+                    LOGGER.info("Saving memory trace to %s", file_name)
+                    prof.export_chrome_trace(file_name)
+
+                return handler_fn
+
+            global_rank = int(os.environ.get("SLURM_PROCID", "0"))  # WON'T WORK WHEN RUNNING WITHOUT SLURM
+            if not (self.config.diagnostics.benchmark_profiler.memory.trace_rank0_only and global_rank != 0):
+                from pytorch_lightning.profilers.pytorch import _KINETO_AVAILABLE
+
+                assert (
+                    _KINETO_AVAILABLE
+                ), "Kineto is not available. Please ensure Kineto is available to be able to use the memory profiler"
+
+                torch.profiler.profile = (
+                    PatchedProfile  # patch the profile(KinetoProfile) object to serialise the distributed info
+                )
+                self.memory_profiler = PyTorchProfiler(
+                    with_stack=True,
+                    emit_nvtx=False,
+                    profile_memory=True,
+                    export_to_chrome=True,
+                    record_shapes=True,
+                    group_by_input_shapes=True,
+                    dirpath=self.dirpath,
+                    on_trace_ready=trace_handler(self.dirpath),
+                    schedule=torch.profiler.schedule(
+                        wait=0,
+                        warmup=self.warmup,
+                        active=self.num_steps,
+                        repeat=1,
+                        skip_first=self.config.training.num_sanity_val_steps,
+                    ),
+                )
+        self.time_rows_dict = None  # updated if we create a memory profile report
+
+    def start(self, action_name: str) -> None:
+        """Starts recording for a specific action.
+
+        Parameters
+        ----------
+        action_name : str
+            Name of the action.
+        """
+        self.time_profiler.start(action_name)
+        self.memory_profiler.start(action_name)
+
+    def stop(self, action_name: str) -> None:
+        """Stops recording for a specific action.
+
+        Parameters
+        ----------
+        action_name : str
+            Name of the action.
+        """
+        self.time_profiler.stop(action_name)
+        self.memory_profiler.stop(action_name)
+
+    def _trim_time_report(self, recorded_actions: dict) -> dict[str, float]:
+        all_actions_names = recorded_actions.keys()
+        df = pd.DataFrame({"Strings": all_actions_names})
+        combined_pattern = "|".join(PROFILER_ACTIONS)
+        filtered_df = df[df["Strings"].str.contains(combined_pattern, regex=True, na=False)]
+        trimmed_actions_names = filtered_df["Strings"].tolist()
+        return {key: recorded_actions[key] for key in trimmed_actions_names}
+
+    def get_time_profiler_df(self, precision: int = 5) -> pd.DataFrame:
+        """Retrieves a DataFrame with time profiling information.
+
+        Parameters
+        ----------
+        precision : int
+            Precision for rounding, by default 5
+
+        Returns
+        -------
+        pd.DataFrame
+            DataFrame with time profiling information.
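+
+        Notes
+        -----
+        Unless ``benchmark_profiler.time.verbose`` is enabled, the recorded
+        actions are first filtered against the PROFILER_ACTIONS regex list,
+        and per-callback rows are aggregated into ``*Callback_<name>``
+        categories.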
+ """ + if self.config.diagnostics.benchmark_profiler.time.verbose is False: + self.time_profiler.recorded_durations = self._trim_time_report( + recorded_actions=self.time_profiler.recorded_durations, + ) + time_df = pd.DataFrame(self.time_profiler.recorded_durations.items()) + time_df[2] = time_df[1].apply(len) + time_df[3] = time_df[1].apply(np.mean) + time_df[1] = time_df[1].apply(sum) + time_df.columns = ["name", "total_time", "n_calls", "avg_time"] + + def replace_function(value: str) -> str: + # Replace 'apple' with 'fruit' + return re.sub(r"\{.*?\}", "", value) # Remove anything between braces + + time_df["name"] = time_df["name"].apply(replace_function) + pattern = r"\[(.*?)\]|(.*)" + time_df["category"] = time_df["name"].str.extract(pattern, expand=False)[0].fillna(time_df["name"]) + + pattern = re.compile(r"\[Callback\](.*?)\.") + # Apply the regular expression to the column + callbacks_subcategories = "*Callback_" + time_df[time_df["category"] == "Callback"]["name"].str.extract(pattern) + indexer = time_df[time_df["category"] == "Callback"].index + time_df.loc[indexer, "category"] = callbacks_subcategories[0].tolist() + + # Check if 'Callback' is present in the 'category' column + time_df["is_callback"] = time_df["category"].str.contains("Callback", case=False) + + # Group by the 'is_callback' column and apply groupby operation only on rows with 'Callback' in 'category' + grouped_data = ( + time_df[time_df["is_callback"]] + .groupby("category") + .agg({"n_calls": "sum", "avg_time": "sum", "total_time": "sum"}) + .reset_index() + ) + grouped_data["name"] = grouped_data["category"] + + time_df = pd.concat([time_df[~time_df["is_callback"]], grouped_data]) + time_df = time_df.drop("is_callback", axis=1) + time_df = time_df.round(precision) + time_df = time_df.sort_values(by="category", ascending=False) + + self.time_report_fname = self.dirpath / "time_profiler.csv" + self._save_report(time_df, self.time_report_fname) + return time_df + + @staticmethod + def to_df(sample_dict: dict[str, float], precision: str = ".5") -> pd.DataFrame: + df = pd.DataFrame(sample_dict.items()) + df.columns = ["metric", "value"] + df.value = df.value.apply(lambda x: f"%{precision}f" % x) + return df + + @rank_zero_only + def get_system_profiler_df(self, logger_name: str, logger: pl.loggers.Logger) -> pd.DataFrame: + if logger_name == "wandb": + system_metrics_df = self.to_df(WandBSystemSummarizer(logger).summarize_system_metrics()) + elif logger_name == "mlflow": + system_metrics_df = MLFlowSystemSummarizer(logger).summarize_mlflow_system_metrics() + elif logger_name == "tensorboard": + LOGGER.info("No system profiler data available for Tensorboard") + system_metrics_df = None + + self.system_report_fname = self.dirpath / "system_profiler.csv" + self._save_report(system_metrics_df, self.system_report_fname) + return system_metrics_df + + def _save_report(self, df: pd.DataFrame, fname: Path) -> None: + df.to_csv(fname) + + def _save_model_summary(self, model_summary: str, fname: Path) -> None: + with fname.open("w") as f: + f.write(model_summary) + f.close() + + def get_model_summary(self, model: GraphForecaster, example_input_array: np.ndarray) -> str: + + from torchinfo import summary + + # when using flash attention model, we need to convert the input and model to float16 and cuda + # since FlashAttention only supports fp16 and bf16 data type + example_input_array = example_input_array.to(dtype=torch.float16) + example_input_array = example_input_array.to("cuda") + model.half() + model = 
model.to("cuda") + + summary_str = str( + summary( + model, + input_data=example_input_array, + depth=20, + col_width=16, + col_names=["trainable", "input_size", "output_size", "num_params", "params_percent", "mult_adds"], + row_settings=["var_names"], + verbose=0, + ), + ) + self.model_summary_fname = self.dirpath / "model_summary.txt" + self._save_model_summary(summary_str, self.model_summary_fname) + return summary_str + + @rank_zero_only + def get_speed_profiler_df(self, progressbar: _tqdm) -> pd.DataFrame: + """Computes the speed metrics based on training and validation rates.""" + speed_metrics = {} + + batch_size_tr = self.config.dataloader.batch_size.training + batch_size_val = self.config.dataloader.batch_size.validation + + training_rates_array = np.array(progressbar.training_rates) + speed_metrics["training_avg_throughput"] = training_rates_array.mean() + speed_metrics["training_avg_throughput_per_sample"] = training_rates_array.mean() / batch_size_tr + + validation_rates_array = np.array(progressbar.validation_rates) + speed_metrics["validation_avg_throughput"] = validation_rates_array.mean() + speed_metrics["validation_avg_throughput_per_sample"] = validation_rates_array.mean() / batch_size_val + + # Calculate per_sample metrics + speed_metrics["avg_training_dataloader_throughput"] = ( + 1 / np.array(self.time_profiler.recorded_durations["[_TrainingEpochLoop].train_dataloader_next"]).mean() + ) + speed_metrics["avg_training_dataloader_throughput_per_sample"] = ( + speed_metrics["avg_training_dataloader_throughput"] / batch_size_tr + ) + + speed_metrics["avg_validation_dataloader_throughput"] = ( + 1 / np.array(self.time_profiler.recorded_durations["[_EvaluationLoop].val_next"]).mean() + ) + speed_metrics["avg_validation_dataloader_throughput_per_sample"] = ( + speed_metrics["avg_validation_dataloader_throughput"] / batch_size_val + ) + + if self.time_rows_dict: + speed_metrics.update(self.time_rows_dict) + + speed_profile_df = self.to_df(speed_metrics) + + self.speed_report_fname = self.dirpath / "speed_profiler.csv" + self._save_report(speed_profile_df, self.speed_report_fname) + + return speed_profile_df + + def _save_extra_plots(self) -> None: + if check_torch_version(): + # !it's available for torch >= 2.1 + from torch.cuda._memory_viz import profile_plot + + self.memory_trace_fname = Path(self.dirpath, "memory_trace.html") + with self.memory_trace_fname.open("w") as f: + f.write(profile_plot(self.memory_profiler.profiler)) + + # !it's available for torch >= 2.1 + self.memory_timeline_fname = str(Path(self.dirpath, "memory_timelines.html")) + self.memory_profiler.profiler.export_memory_timeline(self.memory_timeline_fname) + + @rank_zero_only + def get_memory_profiler_df(self) -> pd.DataFrame: + """Retrieves the memory profiler data as a DataFrame. + + Aggregates the results coming from multiple nodes/processes. + + Returns + ------- + pd.DataFrame + Memory profiler data. 
+ """ + if self.config.diagnostics.benchmark_profiler.memory.extra_plots: + self._save_extra_plots() + + self.memory_profiler._delete_profilers() + + if not self.memory_profiler.function_events: + return "" + + data = self.memory_profiler.function_events.key_averages( + group_by_input_shapes=self.memory_profiler._group_by_input_shapes, + ) + table = data.table( + sort_by=self.memory_profiler._sort_by_key, + row_limit=self.memory_profiler._row_limit, + **self.memory_profiler._table_kwargs, + ) # this is a string + + from io import StringIO + + table_main_body = table.split("\n")[:-3] # Remove the last rows + columns = [ + "Name", + "Self CPU %", + "Self CPU", + "CPU total %", + "CPU total", + "CPU time avg", + "Self CUDA", + "Self CUDA %", + "CUDA total", + "CUDA time avg", + "CPU Mem", + "Self CPU Mem", + "CUDA Mem", + "Self CUDA Mem", + "# of Calls", + "Input Shapes", + ] + table_main_body = "\n".join(table_main_body) + memory_df = pd.read_fwf(StringIO(table_main_body), names=columns, skiprows=2) + flag = ["--" not in row for row in memory_df["Name"]] + memory_df = memory_df[flag] + time_rows = [row for row in table.split("\n")[-3:] if row != ""] + if time_rows: + time_rows_dict = {} + for row in time_rows: + key, val = row.split(":") + val = convert_to_seconds(val.strip()) + time_rows_dict[key] = val + self.time_rows_dict = time_rows_dict + + memory_df = memory_df[~memory_df["Name"].isin(time_rows)] + + self.memory_report_fname = self.dirpath / "memory_profiler.csv" + self._save_report(memory_df, self.memory_report_fname) + return memory_df + + +class ProfilerProgressBar(TQDMProgressBar): + """Custom PyTorch Lightning progress bar with profiling functionality. + + Attributes + ---------- + validation_rates : list[float] + List to store validation rates (it/s). + training_rates : list[float] + List to store training rates (it/s). + """ + + def __init__(self): + super().__init__() + self.validation_rates = [] + self.training_rates = [] + + def _extract_rate(self, pbar: _tqdm) -> float: + """Extracts the iteration rate from the progress bar. + + Parameters + ---------- + pbar : tqdm + The progress bar. + + Returns + ------- + float + The iteration rate. 
+ """ + return (pbar.format_dict["n"] - pbar.format_dict["initial"]) / pbar.format_dict["elapsed"] + + def on_train_batch_end( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + outputs: STEP_OUTPUT, + batch: Any, + batch_idx: int, + ) -> None: + """Appends the rate from the progress bar to the list of 'training_rates'.""" + batch_idx + 1 + super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx) + if self.train_progress_bar.format_dict["n"] != 0: + self.training_rates.append(self._extract_rate(self.train_progress_bar)) + for logger in self.trainer.loggers: + if isinstance(logger, AnemoiMLflowLogger): + logger.log_metrics({"training_rate": self.training_rates[-1]}, step=trainer.global_step) + + def on_validation_batch_end( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + outputs: STEP_OUTPUT, + batch: Any, + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + """Append rate from the progress bar to the list of 'validation_rates'.""" + super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) + if self.val_progress_bar.format_dict["n"] != 0: + self.validation_rates.append(self._extract_rate(self.val_progress_bar)) + for logger in self.trainer.loggers: + if isinstance(logger, AnemoiMLflowLogger): + logger.log_metrics({"validation_rate": self.validation_rates[-1]}, step=trainer.global_step) diff --git a/src/anemoi/training/distributed/__init__.py b/src/anemoi/training/distributed/__init__.py index 282d6a69..c167afa2 100644 --- a/src/anemoi/training/distributed/__init__.py +++ b/src/anemoi/training/distributed/__init__.py @@ -1,6 +1,8 @@ -# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# (C) Copyright 2024 Anemoi contributors. +# # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. diff --git a/src/anemoi/training/losses/__init__.py b/src/anemoi/training/losses/__init__.py index 33d7fa0a..c167afa2 100644 --- a/src/anemoi/training/losses/__init__.py +++ b/src/anemoi/training/losses/__init__.py @@ -1,8 +1,8 @@ -# (C) Copyright 2024 ECMWF. +# (C) Copyright 2024 Anemoi contributors. # # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. -# diff --git a/src/anemoi/training/losses/combined.py b/src/anemoi/training/losses/combined.py new file mode 100644 index 00000000..11d2b4fe --- /dev/null +++ b/src/anemoi/training/losses/combined.py @@ -0,0 +1,135 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+ +from __future__ import annotations + +import functools +from typing import Any +from typing import Callable + +import torch + +from anemoi.training.train.forecaster import GraphForecaster + + +class CombinedLoss(torch.nn.Module): + """Combined Loss function.""" + + def __init__( + self, + *extra_losses: dict[str, Any] | Callable, + losses: tuple[dict[str, Any] | Callable] | None = None, + loss_weights: tuple[int, ...], + **kwargs, + ): + """Combined loss function. + + Allows multiple losses to be combined into a single loss function, + and the components weighted. + + If a sub loss function requires additional weightings or code created tensors, + that must be `included_` for this function, and then controlled by the underlying + loss function configuration. + + Parameters + ---------- + losses: tuple[dict[str, Any]| Callable] + Tuple of losses to initialise with `GraphForecaster.get_loss_function`. + Allows for kwargs to be passed, and weighings controlled. + *extra_losses: dict[str, Any] | Callable + Additional arg form of losses to include in the combined loss. + loss_weights : tuple[int, ...] + Weights of each loss function in the combined loss. + kwargs: Any + Additional arguments to pass to the loss functions + + Examples + -------- + >>> CombinedLoss( + {"__target__": "anemoi.training.losses.mse.WeightedMSELoss"}, + loss_weights=(1.0,), + node_weights=node_weights + ) + -------- + >>> CombinedLoss( + losses = [anemoi.training.losses.mse.WeightedMSELoss], + loss_weights=(1.0,), + node_weights=node_weights + ) + Or from the config, + + ``` + training_loss: + __target__: anemoi.training.losses.combined.CombinedLoss + losses: + - __target__: anemoi.training.losses.mse.WeightedMSELoss + - __target__: anemoi.training.losses.mae.WeightedMAELoss + scalars: ['variable'] + loss_weights: [1.0,0.5] + ``` + """ + super().__init__() + + losses = (*(losses or []), *extra_losses) + + assert len(losses) == len(loss_weights), "Number of losses and weights must match" + assert len(losses) > 0, "At least one loss must be provided" + + self.losses = [ + GraphForecaster.get_loss_function(loss, **kwargs) if isinstance(loss, dict) else loss(**kwargs) + for loss in losses + ] + self.loss_weights = loss_weights + + def forward( + self, + pred: torch.Tensor, + target: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + """Calculates the combined loss. 
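The forward pass that follows reduces to a weighted sum of the component losses, L = sum_i w_i * L_i(pred, target). A standalone sanity check of that reduction, with two toy loss callables standing in for the instantiated losses:

```python
import torch

def mse(p: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    return torch.mean((p - t) ** 2)

def mae(p: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    return torch.mean((p - t).abs())

pred, target = torch.randn(4, 8), torch.randn(4, 8)
weights = (1.0, 0.5)
# Weighted sum over component losses, as in CombinedLoss.forward
combined = sum(w * fn(pred, target) for w, fn in zip(weights, (mse, mae)))
torch.testing.assert_close(combined, mse(pred, target) + 0.5 * mae(pred, target))
```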
+ + Parameters + ---------- + pred : torch.Tensor + Prediction tensor, shape (bs, ensemble, lat*lon, n_outputs) + target : torch.Tensor + Target tensor, shape (bs, ensemble, lat*lon, n_outputs) + kwargs: Any + Additional arguments to pass to the loss functions + Will be passed to all loss functions + + Returns + ------- + torch.Tensor + Combined loss + """ + loss = None + for i, loss_fn in enumerate(self.losses): + if loss is not None: + loss += self.loss_weights[i] * loss_fn(pred, target, **kwargs) + else: + loss = self.loss_weights[i] * loss_fn(pred, target, **kwargs) + return loss + + @property + def name(self) -> str: + return "combined_" + "_".join(getattr(loss, "name", loss.__class__.__name__.lower()) for loss in self.losses) + + def __getattr__(self, name: str) -> Callable: + """Allow access to underlying attributes of the loss functions.""" + if not all(hasattr(loss, name) for loss in self.losses): + error_msg = f"Attribute {name} not found in all loss functions" + raise AttributeError(error_msg) + + @functools.wraps(getattr(self.losses[0], name)) + def hidden_func(*args, **kwargs) -> list[Any]: + return [getattr(loss, name)(*args, **kwargs) for loss in self.losses] + + return hidden_func diff --git a/src/anemoi/training/losses/huber.py b/src/anemoi/training/losses/huber.py new file mode 100644 index 00000000..ed5b8d25 --- /dev/null +++ b/src/anemoi/training/losses/huber.py @@ -0,0 +1,104 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +from __future__ import annotations + +import logging + +import torch + +from anemoi.training.losses.weightedloss import BaseWeightedLoss + +LOGGER = logging.getLogger(__name__) + + +class WeightedHuberLoss(BaseWeightedLoss): + """Node-weighted Huber loss.""" + + name = "whuber" + + def __init__( + self, + node_weights: torch.Tensor, + delta: float = 1.0, + ignore_nans: bool = False, + **kwargs, + ) -> None: + """Node- and feature weighted Huber Loss. + + See `Huber loss `_ for more information. + + Parameters + ---------- + node_weights : torch.Tensor of shape (N, ) + Weight of each node in the loss function + delta : float, optional + Threshold for Huber loss, by default 1.0 + ignore_nans : bool, optional + Allow nans in the loss and apply methods ignoring nans for measuring the loss, by default False + """ + super().__init__( + node_weights=node_weights, + ignore_nans=ignore_nans, + **kwargs, + ) + self.delta = delta + + def huber(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """Calculate the Huber loss. + + Parameters + ---------- + pred : torch.Tensor + Prediction tensor, shape (bs, ensemble, lat*lon, n_outputs) + target : torch.Tensor + Target tensor, shape (bs, ensemble, lat*lon, n_outputs) + + Returns + ------- + torch.Tensor + Huber loss + """ + diff = torch.abs(pred - target) + return torch.where(diff < self.delta, 0.5 * torch.square(diff), self.delta * (diff - 0.5 * self.delta)) + + def forward( + self, + pred: torch.Tensor, + target: torch.Tensor, + squash: bool = True, + scalar_indices: tuple[int, ...] | None = None, + without_scalars: list[str] | list[int] | None = None, + ) -> torch.Tensor: + """Calculates the lat-weighted Huber loss. 
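The piecewise expression in `huber` above (quadratic inside the `delta` band, linear outside) matches PyTorch's built-in Huber loss; a quick equivalence check:

```python
import torch
import torch.nn.functional as F

delta = 1.0
pred, target = torch.randn(1000), torch.randn(1000)
diff = torch.abs(pred - target)
ours = torch.where(diff < delta, 0.5 * torch.square(diff), delta * (diff - 0.5 * delta))
torch.testing.assert_close(ours, F.huber_loss(pred, target, reduction="none", delta=delta))
```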
+ + Parameters + ---------- + pred : torch.Tensor + Prediction tensor, shape (bs, ensemble, lat*lon, n_outputs) + target : torch.Tensor + Target tensor, shape (bs, ensemble, lat*lon, n_outputs) + squash : bool, optional + Average last dimension, by default True + scalar_indices: tuple[int,...], optional + Indices to subset the calculated scalar with, by default None + without_scalars: list[str] | list[int] | None, optional + list of scalars to exclude from scaling. Can be list of names or dimensions to exclude. + By default None + + Returns + ------- + torch.Tensor + Weighted Huber loss + """ + out = self.huber(pred, target) + + out = self.scale(out, scalar_indices, without_scalars=without_scalars) + + return self.scale_by_node_weights(out, squash) diff --git a/src/anemoi/training/losses/logcosh.py b/src/anemoi/training/losses/logcosh.py new file mode 100644 index 00000000..6f916177 --- /dev/null +++ b/src/anemoi/training/losses/logcosh.py @@ -0,0 +1,97 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +from __future__ import annotations + +import logging + +import numpy as np +import torch + +from anemoi.training.losses.weightedloss import BaseWeightedLoss + +LOGGER = logging.getLogger(__name__) + + +class LogCosh(torch.autograd.Function): + """LogCosh custom autograd function.""" + + @staticmethod + def forward(ctx, inp: torch.Tensor) -> torch.Tensor: # noqa: ANN001 + ctx.save_for_backward(inp) + abs_input = torch.abs(inp) + return abs_input + torch.nn.functional.softplus(-2 * abs_input) - np.log(2) + + @staticmethod + def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor: # noqa: ANN001 + (inp,) = ctx.saved_tensors + return grad_output * torch.tanh(inp) + + +class WeightedLogCoshLoss(BaseWeightedLoss): + """Node-weighted LogCosh loss.""" + + name = "wlogcosh" + + def __init__( + self, + node_weights: torch.Tensor, + ignore_nans: bool = False, + **kwargs, + ) -> None: + """Node- and feature weighted LogCosh Loss. + + Parameters + ---------- + node_weights : torch.Tensor of shape (N, ) + Weight of each node in the loss function + ignore_nans : bool, optional + Allow nans in the loss and apply methods ignoring nans for measuring the loss, by default False + + """ + super().__init__( + node_weights=node_weights, + ignore_nans=ignore_nans, + **kwargs, + ) + + def forward( + self, + pred: torch.Tensor, + target: torch.Tensor, + squash: bool = True, + scalar_indices: tuple[int, ...] | None = None, + without_scalars: list[str] | list[int] | None = None, + ) -> torch.Tensor: + """Calculates the lat-weighted LogCosh loss. + + Parameters + ---------- + pred : torch.Tensor + Prediction tensor, shape (bs, ensemble, lat*lon, n_outputs) + target : torch.Tensor + Target tensor, shape (bs, ensemble, lat*lon, n_outputs) + squash : bool, optional + Average last dimension, by default True + scalar_indices: tuple[int,...], optional + Indices to subset the calculated scalar with, by default None + without_scalars: list[str] | list[int] | None, optional + list of scalars to exclude from scaling. Can be list of names or dimensions to exclude. 
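`LogCosh.forward` above uses the numerically stable identity log(cosh(x)) = |x| + softplus(-2|x|) - log(2), which avoids overflowing cosh for large |x|; the custom backward uses d/dx log(cosh(x)) = tanh(x). Verifying the identity on a safe range:

```python
import math

import torch
import torch.nn.functional as F

x = torch.linspace(-5, 5, 101)
stable = x.abs() + F.softplus(-2 * x.abs()) - math.log(2)
torch.testing.assert_close(stable, torch.log(torch.cosh(x)))
```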
+ By default None + + Returns + ------- + torch.Tensor + Weighted LogCosh loss + + """ + out = LogCosh.apply(pred - target) + out = self.scale(out, scalar_indices, without_scalars=without_scalars) + return self.scale_by_node_weights(out, squash) diff --git a/src/anemoi/training/losses/mae.py b/src/anemoi/training/losses/mae.py new file mode 100644 index 00000000..b2112d98 --- /dev/null +++ b/src/anemoi/training/losses/mae.py @@ -0,0 +1,83 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +from __future__ import annotations + +import logging + +import torch + +from anemoi.training.losses.weightedloss import BaseWeightedLoss + +LOGGER = logging.getLogger(__name__) + + +class WeightedMAELoss(BaseWeightedLoss): + """Node-weighted MAE loss.""" + + name = "wmae" + + def __init__( + self, + node_weights: torch.Tensor, + ignore_nans: bool = False, + **kwargs, + ) -> None: + """Node- and feature weighted MAE Loss. + + Also known as the Weighted L1 loss. + + Parameters + ---------- + node_weights : torch.Tensor of shape (N, ) + Weight of each node in the loss function + ignore_nans : bool, optional + Allow nans in the loss and apply methods ignoring nans for measuring the loss, by default False + + """ + super().__init__( + node_weights=node_weights, + ignore_nans=ignore_nans, + **kwargs, + ) + + def forward( + self, + pred: torch.Tensor, + target: torch.Tensor, + squash: bool = True, + scalar_indices: tuple[int, ...] | None = None, + without_scalars: list[str] | list[int] | None = None, + ) -> torch.Tensor: + """Calculates the lat-weighted MAE loss. + + Parameters + ---------- + pred : torch.Tensor + Prediction tensor, shape (bs, ensemble, lat*lon, n_outputs) + target : torch.Tensor + Target tensor, shape (bs, ensemble, lat*lon, n_outputs) + squash : bool, optional + Average last dimension, by default True + scalar_indices: tuple[int,...], optional + Indices to subset the calculated scalar with, by default None + without_scalars: list[str] | list[int] | None, optional + list of scalars to exclude from scaling. Can be list of names or dimensions to exclude. + By default None + + + Returns + ------- + torch.Tensor + Weighted MAE loss + """ + out = torch.abs(pred - target) + out = self.scale(out, scalar_indices, without_scalars=without_scalars) + return self.scale_by_node_weights(out, squash) diff --git a/src/anemoi/training/losses/mse.py b/src/anemoi/training/losses/mse.py index 88ad0d0b..c30f8b9d 100644 --- a/src/anemoi/training/losses/mse.py +++ b/src/anemoi/training/losses/mse.py @@ -13,80 +13,68 @@ import logging import torch -from torch import nn + +from anemoi.training.losses.weightedloss import BaseWeightedLoss LOGGER = logging.getLogger(__name__) -class WeightedMSELoss(nn.Module): - """Latitude-weighted MSE loss.""" +class WeightedMSELoss(BaseWeightedLoss): + """Node-weighted MSE loss.""" + + name = "wmse" def __init__( self, node_weights: torch.Tensor, - data_variances: torch.Tensor | None = None, - ignore_nans: bool | None = False, + ignore_nans: bool = False, + **kwargs, ) -> None: - """Latitude- and (inverse-)variance-weighted MSE Loss. + """Node- and feature weighted MSE Loss. 
Parameters ---------- node_weights : torch.Tensor of shape (N, ) Weight of each node in the loss function - data_variances : Optional[torch.Tensor], optional - precomputed, per-variable stepwise variance estimate, by default None ignore_nans : bool, optional Allow nans in the loss and apply methods ignoring nans for measuring the loss, by default False """ - super().__init__() - - self.avg_function = torch.nanmean if ignore_nans else torch.mean - self.sum_function = torch.nansum if ignore_nans else torch.sum - - self.register_buffer("weights", node_weights, persistent=True) - if data_variances is not None: - self.register_buffer("ivar", data_variances, persistent=True) + super().__init__( + node_weights=node_weights, + ignore_nans=ignore_nans, + **kwargs, + ) def forward( self, pred: torch.Tensor, target: torch.Tensor, squash: bool = True, + scalar_indices: tuple[int, ...] | None = None, + without_scalars: list[str] | list[int] | None = None, ) -> torch.Tensor: """Calculates the lat-weighted MSE loss. Parameters ---------- pred : torch.Tensor - Prediction tensor, shape (bs, lat*lon, n_outputs) + Prediction tensor, shape (bs, ensemble, lat*lon, n_outputs) target : torch.Tensor - Target tensor, shape (bs, lat*lon, n_outputs) + Target tensor, shape (bs, ensemble, lat*lon, n_outputs) squash : bool, optional Average last dimension, by default True + scalar_indices: tuple[int,...], optional + Indices to subset the calculated scalar with, by default None + without_scalars: list[str] | list[int] | None, optional + list of scalars to exclude from scaling. Can be list of names or dimensions to exclude. + By default None Returns ------- torch.Tensor Weighted MSE loss - """ out = torch.square(pred - target) - - # Use variances if available - if hasattr(self, "ivar"): - out *= self.ivar - - # Squash by last dimension - if squash: - out = self.avg_function(out, dim=-1) - # Weight by area - out *= self.weights.expand_as(out) - out /= self.sum_function(self.weights.expand_as(out)) - return self.sum_function(out) - - # Weight by area - out *= self.weights[..., None].expand_as(out) - # keep last dimension (variables) when summing weights - out /= self.sum_function(self.weights[..., None].expand_as(out), axis=(0, 1, 2)) - return self.sum_function(out, axis=(0, 1, 2)) + out = self.scale(out, scalar_indices, without_scalars=without_scalars) + return self.scale_by_node_weights(out, squash) diff --git a/src/anemoi/training/losses/rmse.py b/src/anemoi/training/losses/rmse.py new file mode 100644 index 00000000..6c97344a --- /dev/null +++ b/src/anemoi/training/losses/rmse.py @@ -0,0 +1,84 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +from __future__ import annotations + +import logging + +import torch + +from anemoi.training.losses.mse import BaseWeightedLoss + +LOGGER = logging.getLogger(__name__) + + +class WeightedRMSELoss(BaseWeightedLoss): + """Node-weighted RMSE loss.""" + + name = "wrmse" + + def __init__( + self, + node_weights: torch.Tensor, + ignore_nans: bool = False, + **kwargs, + ) -> None: + """Node- and (inverse-)variance-weighted RMSE Loss. 
+ + Parameters + ---------- + node_weights : torch.Tensor of shape (N, ) + Weight of each node in the loss function + ignore_nans : bool, optional + Allow nans in the loss and apply methods ignoring nans for measuring the loss, by default False + """ + super().__init__( + node_weights=node_weights, + ignore_nans=ignore_nans, + **kwargs, + ) + + def forward( + self, + pred: torch.Tensor, + target: torch.Tensor, + squash: bool = True, + scalar_indices: tuple[int, ...] | None = None, + without_scalars: list[str] | list[int] | None = None, + ) -> torch.Tensor: + """Calculates the lat-weighted RMSE loss. + + Parameters + ---------- + pred : torch.Tensor + Prediction tensor, shape (bs, ensemble, lat*lon, n_outputs) + target : torch.Tensor + Target tensor, shape (bs, ensemble, lat*lon, n_outputs) + squash : bool, optional + Average last dimension, by default True + scalar_indices: tuple[int,...], optional + Indices to subset the calculated scalar with, by default None + without_scalars: list[str] | list[int] | None, optional + list of scalars to exclude from scaling. Can be list of names or dimensions to exclude. + By default None + + Returns + ------- + torch.Tensor + Weighted RMSE loss + """ + mse = super().forward( + pred=pred, + target=target, + squash=squash, + scalar_indices=scalar_indices, + without_scalars=without_scalars, + ) + return torch.sqrt(mse) diff --git a/src/anemoi/training/losses/utils.py b/src/anemoi/training/losses/utils.py index 5ddef3d6..e98e0bfe 100644 --- a/src/anemoi/training/losses/utils.py +++ b/src/anemoi/training/losses/utils.py @@ -11,10 +11,17 @@ from __future__ import annotations import logging +import uuid +from typing import TYPE_CHECKING +from typing import Callable +from typing import Union import torch from torch import nn +if TYPE_CHECKING: + from collections.abc import Sequence + LOGGER = logging.getLogger(__name__) @@ -27,7 +34,7 @@ def grad_scaler( Uses the formula in https://arxiv.org/pdf/2306.06079.pdf, section 4.3.2 - Use .register_full_backward_hook(grad_scaler, prepend=False) to register this hook. + Use .register_full_backward_hook(grad_scalar, prepend=False) to register this hook. Parameters ---------- @@ -52,3 +59,451 @@ def grad_scaler( (channels * channel_weights) / torch.sum(channel_weights, dim=-1, keepdim=True) * grad_in[0] ) # rescaled gradient return new_grad_in, grad_in[1] + + +TENSOR_SPEC = tuple[Union[int, tuple[int]], torch.Tensor] + + +class Shape: + """Shape resolving object.""" + + def __init__(self, func: Callable[[int], int]): + self.func = func + + def __getitem__(self, dimension: int) -> int: + return self.func(dimension) + + +# TODO(Harrison Cook): Consider moving this to subclass from a pytorch object and allow for device moving completely +class ScaleTensor: + """Dynamically resolved tensor scaling class. + + Allows a user to specify a scalar and the dimensions it should be applied to. + The class will then enforce that additional scalars are compatible with the specified dimensions. + + When `get_scalar` or `scale` is called, the class will return the product of all scalars, resolved + to the dimensional size of the input tensor. + + Additionally, the class can be subsetted to only return a subset of the scalars, but + only from those given names. 
+ + Examples + -------- + >>> tensor = torch.randn(3, 4, 5) + >>> scalars = ScaleTensor((0, torch.randn(3)), (1, torch.randn(4))) + >>> scaled_tensor = scalars.scale(tensor) + >>> scalars.get_scalar(tensor.ndim).shape + torch.Size([3, 4, 1]) + >>> scalars.add_scalar(-1, torch.randn(5)) + >>> scalars.get_scalar(tensor.ndim).shape + torch.Size([3, 4, 5]) + """ + + tensors: dict[str, TENSOR_SPEC] + _specified_dimensions: dict[str, tuple[int]] + + def __init__( + self, + scalars: dict[str, TENSOR_SPEC] | TENSOR_SPEC | None = None, + *tensors: TENSOR_SPEC, + **named_tensors: dict[str, TENSOR_SPEC], + ): + """ScaleTensor constructor. + + Parameters + ---------- + scalars : dict[str, TENSOR_SPEC] | TENSOR_SPEC | None, optional + Scalars to initalise with, by default None + tensors : TENSOR_SPEC + Args form of (dimension, tensor) to add to the scalars + Will be given a random uuid name + named_tensors : dict[str, TENSOR_SPEC] + Kwargs form of {name: (dimension, tensor)} to add to the scalars + """ + self.tensors = {} + self._specified_dimensions = {} + + named_tensors.update(scalars or {}) + self.add(named_tensors) + + for tensor_spec in tensors: + self.add_scalar(*tensor_spec) + + @property + def shape(self) -> Shape: + """Get the shape of the scale tensor. + + Returns a Shape object to be indexed, + Will only resolve those dimensions specified in the tensors. + """ + + def get_dim_shape(dimension: int) -> int: + for dim_assign, tensor in self.tensors.values(): + if isinstance(dim_assign, tuple) and dimension in dim_assign: + return tensor.shape[list(dim_assign).index(dimension)] + + unique_dims = {dim for dim_assign in self._specified_dimensions.values() for dim in dim_assign} + error_msg = ( + f"Could not find shape of dimension {dimension}. " + f"Tensors are only specified for dimensions {list(unique_dims)}." + ) + raise IndexError(error_msg) + + return Shape(get_dim_shape) + + def validate_scalar(self, dimension: int | tuple[int], scalar: torch.Tensor) -> None: + """Check if the scalar is compatible with the given dimension. + + Parameters + ---------- + dimension : int | tuple[int] + Dimensions to check `scalar` against + scalar : torch.Tensor + Scalar tensor to check + + Raises + ------ + ValueError + If the scalar is not compatible with the given dimension + """ + if isinstance(dimension, int): + dimension = [dimension] + + for scalar_dim, dim in enumerate(dimension): + if dim not in self or scalar.shape[scalar_dim] == 1 or self.shape[dim] == 1: + continue + + if self.shape[dim] != scalar.shape[scalar_dim]: + error_msg = ( + f"Incoming scalar shape {scalar.shape} at dimension {scalar_dim} " + f"does not match shape of saved scalar. Expected {self.shape[dim]}" + ) + raise ValueError(error_msg) + + def add_scalar( + self, + dimension: int | tuple[int], + scalar: torch.Tensor, + *, + name: str | None = None, + ) -> None: + """Add new scalar to be applied along `dimension`. + + Dimension can be a single int even for a multi-dimensional scalar, + in this case the dimensions are assigned as a range starting from the given int. + Negative indexes are also valid, and will be resolved against the tensor's ndim. 
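As described above, passing an `int` dimension together with a multi-dimensional scalar expands it to a contiguous range of dimensions. A small sketch, assuming the `anemoi.training.losses.utils` module path introduced in this diff:

```python
import torch

from anemoi.training.losses.utils import ScaleTensor

scalars = ScaleTensor()
scalars.add_scalar(1, torch.randn(4, 5), name="block")  # 2-D scalar: occupies dims (1, 2)
assert (1, 2) in scalars

scalars.add_scalar(-1, torch.randn(5), name="trailing")  # relative index, resolved at use time
assert "trailing" in scalars
```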
+ + Parameters + ---------- + dimension : int | tuple[int] + Dimension/s to apply the scalar to + scalar : torch.Tensor + Scalar tensor to apply + name : str | None, optional + Name of the scalar, by default None + """ + if not isinstance(scalar, torch.Tensor): + scalar = torch.tensor([scalar]) if isinstance(scalar, (int, float)) else torch.tensor(scalar) + + if isinstance(dimension, int): + if len(scalar.shape) == 1: + dimension = (dimension,) + else: + dimension = tuple(dimension + i for i in range(len(scalar.shape))) + else: + dimension = tuple(dimension) + + if name is None: + name = str(uuid.uuid4()) + + if name in self.tensors: + msg = f"Scalar {name!r} already exists in scalars." + raise ValueError(msg) + + try: + self.validate_scalar(dimension, scalar) + except ValueError as e: + error_msg = f"Validating tensor {name!r} raised an error." + raise ValueError(error_msg) from e + + self.tensors[name] = (dimension, scalar) + self._specified_dimensions[name] = dimension + + def update_scalar(self, name: str, scalar: torch.Tensor, *, override: bool = False) -> None: + """Update an existing scalar maintaining original dimensions. + + If `override` is False, the scalar must be valid against the original dimensions. + If `override` is True, the scalar will be updated regardless of validity against original scalar. + + Parameters + ---------- + name : str + Name of the scalar to update + scalar : torch.Tensor + New scalar tensor + override : bool, optional + Whether to override the scalar ignoring dimension compatibility, by default False + """ + if name not in self.tensors: + msg = f"Scalar {name!r} not found in scalars." + raise ValueError(msg) + + dimension = self.tensors[name][0] + + if not override: + self.validate_scalar(dimension, scalar) + + original_scalar = self.tensors.pop(name) + original_dimension = self._specified_dimensions.pop(name) + + try: + self.add_scalar(dimension, scalar, name=name) + except ValueError: + self.tensors[name] = original_scalar + self._specified_dimensions[name] = original_dimension + raise + + def add(self, new_scalars: dict[str, TENSOR_SPEC] | list[TENSOR_SPEC] | None = None, **kwargs) -> None: + """Add multiple scalars to the existing scalars. + + Parameters + ---------- + new_scalars : dict[str, TENSOR_SPEC] | list[TENSOR_SPEC] | None, optional + Scalars to add, see `add_scalar` for more info, by default None + **kwargs: + Kwargs form of {name: (dimension, tensor)} to add to the scalars + """ + if isinstance(new_scalars, list): + for tensor_spec in new_scalars: + self.add_scalar(*tensor_spec) + else: + kwargs.update(new_scalars or {}) + for name, tensor_spec in kwargs.items(): + self.add_scalar(*tensor_spec, name=name) + + def update(self, updated_scalars: dict[str, torch.Tensor] | None = None, override: bool = False, **kwargs) -> None: + """Update multiple scalars in the existing scalars. + + If `override` is False, the scalar must be valid against the original dimensions. + If `override` is True, the scalar will be updated regardless of shape. 
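Usage sketch for the update path described above: a named scalar is revalidated against its registered dimensions unless `override=True` is passed.

```python
import torch

from anemoi.training.losses.utils import ScaleTensor

scalars = ScaleTensor(variable=(-1, torch.ones(7)))
scalars.update_scalar("variable", torch.full((7,), 0.5))         # shape-checked against dim -1
scalars.update_scalar("variable", torch.ones(1), override=True)  # singleton always broadcasts
scalars.update({"variable": torch.ones(7)})                      # dict form via update()
```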
+ + Parameters + ---------- + updated_scalars : dict[str, torch.Tensor] | None, optional + Scalars to update, referenced by name, by default None + override : bool, optional + Whether to override the scalar ignoring dimension compatibility, by default False + **kwargs: + Kwargs form of {name: tensor} to update in the scalars + """ + kwargs.update(updated_scalars or {}) + for name, tensor in kwargs.items(): + self.update_scalar(name, tensor, override=override) + + def subset(self, scalars: str | Sequence[str]) -> ScaleTensor: + """Get subset of the scalars, filtering by name. + + See `.subset_by_dim` for subsetting by affected dimensions. + + Parameters + ---------- + scalars : str | Sequence[str] + Name/s of the scalars to get + + Returns + ------- + ScaleTensor + Subset of self + """ + if isinstance(scalars, str): + scalars = [scalars] + return ScaleTensor(**{name: self.tensors[name] for name in scalars}) + + def without(self, scalars: str | Sequence[str]) -> ScaleTensor: + """Get subset of the scalars, filtering out by name. + + Parameters + ---------- + scalars : str | Sequence[str] + Name/s of the scalars to exclude + + Returns + ------- + ScaleTensor + Subset of self + """ + if isinstance(scalars, str): + scalars = [scalars] + return ScaleTensor(**{name: tensor for name, tensor in self.tensors.items() if name not in scalars}) + + def subset_by_dim(self, dimensions: int | Sequence[int]) -> ScaleTensor: + """Get subset of the scalars, filtering by dimension. + + See `.subset` for subsetting by name. + + Parameters + ---------- + dimensions : int | Sequence[int] + Dimensions to get scalars of + + Returns + ------- + ScaleTensor + Subset of self + """ + subset_scalars: dict[str, TENSOR_SPEC] = {} + + if isinstance(dimensions, int): + dimensions = (dimensions,) + + for name, (dim, scalar) in self.tensors.items(): + if isinstance(dim, int): + dim = (dim,) + if len(set(dimensions).intersection(dim)) > 0: + subset_scalars[name] = (dim, scalar) + + return ScaleTensor(**subset_scalars) + + def without_by_dim(self, dimensions: int | Sequence[int]) -> ScaleTensor: + """Get subset of the scalars, filtering out by dimension. + + Parameters + ---------- + dimensions : int | Sequence[int] + Dimensions to exclude scalars of + + Returns + ------- + ScaleTensor + Subset of self + """ + subset_scalars: dict[str, TENSOR_SPEC] = {} + + if isinstance(dimensions, int): + dimensions = (dimensions,) + + for name, (dim, scalar) in self.tensors.items(): + if isinstance(dim, int): + dim = (dim,) + if len(set(dimensions).intersection(dim)) == 0: + subset_scalars[name] = (dim, scalar) + + return ScaleTensor(**subset_scalars) + + def resolve(self, ndim: int) -> ScaleTensor: + """Resolve relative indexes in scalars by associating against ndim. + + i.e. if a scalar was given as effecting dimension -1, + and `ndim` was provided as 4, the scalar will be fixed + to effect dimension 3. + + Parameters + ---------- + ndim : int + Number of dimensions to resolve relative indexing against + + Returns + ------- + ScaleTensor + ScaleTensor with all relative indexes resolved + """ + resolved_scalars: dict[str, TENSOR_SPEC] = {} + + for name, (dims, scalar) in self.tensors.items(): + if any(d < 0 for d in dims): + dims = [d if d >= 0 else ndim + d for d in dims] + resolved_scalars[name] = (dims, scalar) + + return ScaleTensor(**resolved_scalars) + + def scale(self, tensor: torch.Tensor) -> torch.Tensor: + """Scale a given tensor by the scalars. 
+ + Parameters + ---------- + tensor : torch.Tensor + Input tensor to scale + + Returns + ------- + torch.Tensor + Scaled tensor + """ + return tensor * self.get_scalar(tensor.ndim, device=tensor.device) + + def get_scalar(self, ndim: int, device: str | None = None) -> torch.Tensor: + """Get completely resolved scalar tensor. + + Parameters + ---------- + ndim : int + Number of dimensions of the tensor to resolve the scalars to + Used to resolve relative indices, and add singleton dimensions + device: str | None, optional + Device to move the scalar to, by default None + + Returns + ------- + torch.Tensor + Scalar tensor + + Raises + ------ + ValueError + If resolving relative indices is invalid + """ + complete_scalar = None + + tensors = self.resolve(ndim).tensors + + for dims, scalar in tensors.values(): + missing_dims = [d for d in range(ndim) if d not in dims] + reshape = [1] * len(missing_dims) + reshape.extend(scalar.shape) + + reshaped_scalar = scalar.reshape(reshape) + reshaped_scalar = torch.moveaxis(reshaped_scalar, list(range(ndim)), (*missing_dims, *dims)) + + complete_scalar = reshaped_scalar if complete_scalar is None else complete_scalar * reshaped_scalar + + complete_scalar = torch.ones(1) if complete_scalar is None else complete_scalar + + if device is not None: + return complete_scalar.to(device) + return complete_scalar + + def to(self, *args, **kwargs) -> None: + """Move scalars inplace.""" + for name, (dims, tensor) in self.tensors.items(): + self.tensors[name] = (dims, tensor.to(*args, **kwargs)) + + def __mul__(self, tensor: torch.Tensor) -> torch.Tensor: + return self.scale(tensor) + + def __rmul__(self, tensor: torch.Tensor) -> torch.Tensor: + return self.scale(tensor) + + def __repr__(self): + return ( + f"ScalarTensor:\n - With tensors : {list(self.tensors.keys())}\n" + f" - In dimensions : {list(self._specified_dimensions.values())}" + ) + + def __contains__(self, dimension: int | tuple[int] | str) -> bool: + """Check if either scalar by name or dimension by int/tuple is being scaled.""" + if isinstance(dimension, tuple): + return dimension in self._specified_dimensions.values() + if isinstance(dimension, str): + return dimension in self.tensors + + result = False + for dim_assign, _ in self.tensors.values(): + result = dimension in dim_assign or result + return result + + def __len__(self): + return len(self.tensors) + + def __iter__(self): + """Iterate over tensors.""" + return iter(self.tensors) diff --git a/src/anemoi/training/losses/weightedloss.py b/src/anemoi/training/losses/weightedloss.py new file mode 100644 index 00000000..0deccc9d --- /dev/null +++ b/src/anemoi/training/losses/weightedloss.py @@ -0,0 +1,248 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
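`get_scalar` above materialises the product of all registered scalars, broadcast to the requested rank, so `scale` is just an elementwise multiply. A check of that behaviour:

```python
import torch

from anemoi.training.losses.utils import ScaleTensor

x = torch.randn(2, 3, 4)
scalars = ScaleTensor({"node": (0, torch.rand(2)), "variable": (2, torch.rand(4))})

resolved = scalars.get_scalar(x.ndim)
assert resolved.shape == (2, 1, 4)  # singleton inserted for the unscaled dim 1
torch.testing.assert_close(scalars.scale(x), x * resolved)
```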
+ + +from __future__ import annotations + +import functools +import logging +from abc import ABC +from abc import abstractmethod + +import torch +from torch import nn + +from anemoi.training.losses.utils import ScaleTensor + +LOGGER = logging.getLogger(__name__) + + +class BaseWeightedLoss(nn.Module, ABC): + """Node-weighted general loss.""" + + scalar: ScaleTensor + + def __init__( + self, + node_weights: torch.Tensor, + ignore_nans: bool = False, + ) -> None: + """Node- and feature_weighted Loss. + + Exposes: + - self.avg_function: torch.nanmean or torch.mean + - self.sum_function: torch.nansum or torch.sum + depending on the value of `ignore_nans` + + Registers: + - self.node_weights: torch.Tensor of shape (N, ) + - self.scalar: ScaleTensor modified with `add_scalar` and `update_scalar` + + Parameters + ---------- + node_weights : torch.Tensor of shape (N, ) + Weight of each node in the loss function + ignore_nans : bool, optional + Allow nans in the loss and apply methods ignoring nans for measuring the loss, by default False + + """ + super().__init__() + + self.scalar = ScaleTensor() + + self.avg_function = torch.nanmean if ignore_nans else torch.mean + self.sum_function = torch.nansum if ignore_nans else torch.sum + + self.register_buffer("node_weights", node_weights, persistent=True) + + @functools.wraps(ScaleTensor.add_scalar, assigned=("__doc__", "__annotations__")) + def add_scalar(self, dimension: int | tuple[int], scalar: torch.Tensor, *, name: str | None = None) -> None: + self.scalar.add_scalar(dimension=dimension, scalar=scalar, name=name) + + @functools.wraps(ScaleTensor.update_scalar, assigned=("__doc__", "__annotations__")) + def update_scalar(self, name: str, scalar: torch.Tensor, *, override: bool = False) -> None: + self.scalar.update_scalar(name=name, scalar=scalar, override=override) + + def scale( + self, + x: torch.Tensor, + scalar_indices: tuple[int, ...] | None = None, + *, + without_scalars: list[str] | list[int] | None = None, + ) -> torch.Tensor: + """Scale a tensor by the variable_scaling. + + Parameters + ---------- + x : torch.Tensor + Tensor to be scaled, shape (bs, ensemble, lat*lon, n_outputs) + scalar_indices: tuple[int,...], optional + Indices to subset the calculated scalar with, by default None. + without_scalars: list[str] | list[int] | None, optional + list of scalars to exclude from scaling. Can be list of names or dimensions to exclude. + By default None + + Returns + ------- + torch.Tensor + Scaled error tensor + """ + if len(self.scalar) == 0: + return x + + scale_tensor = self.scalar + if without_scalars is not None and len(without_scalars) > 0: + if isinstance(without_scalars[0], str): + scale_tensor = self.scalar.without(without_scalars) + else: + scale_tensor = self.scalar.without_by_dim(without_scalars) + + scalar = scale_tensor.get_scalar(x.ndim).to(x) + + if scalar_indices is None: + return x * scalar + return x * scalar[scalar_indices] + + def scale_by_node_weights(self, x: torch.Tensor, squash: bool = True) -> torch.Tensor: + """Scale a tensor by the node_weights. + + Equivalent to reducing and averaging accordingly across all + dimensions of the tensor. 
+ + Parameters + ---------- + x : torch.Tensor + Tensor to be scaled, shape (bs, ensemble, lat*lon, n_outputs) + squash : bool, optional + Average last dimension, by default True + If False, the loss returned of shape (n_outputs) + + Returns + ------- + torch.Tensor + Scaled error tensor + """ + # Squash by last dimension + if squash: + x = self.avg_function(x, dim=-1) + # Weight by area + x *= self.node_weights.expand_as(x) + x /= self.sum_function(self.node_weights.expand_as(x)) + return self.sum_function(x) + + # Weight by area, due to weighting construction is analagous to a mean + x *= self.node_weights[..., None].expand_as(x) + # keep last dimension (variables) when summing weights + x /= self.sum_function(self.node_weights[..., None].expand_as(x), dim=(0, 1, 2)) + return self.sum_function(x, dim=(0, 1, 2)) + + @abstractmethod + def forward( + self, + pred: torch.Tensor, + target: torch.Tensor, + squash: bool = True, + *, + scalar_indices: tuple[int, ...] | None = None, + without_scalars: list[str] | list[int] | None = None, + ) -> torch.Tensor: + """Calculates the lat-weighted scaled loss. + + Parameters + ---------- + pred : torch.Tensor + Prediction tensor, shape (bs, ensemble, lat*lon, n_outputs) + target : torch.Tensor + Target tensor, shape (bs, ensemble, lat*lon, n_outputs) + squash : bool, optional + Average last dimension, by default True + scalar_indices: tuple[int,...], optional + Indices to subset the calculated scalar with, by default None + without_scalars: list[str] | list[int] | None, optional + list of scalars to exclude from scaling. Can be list of names or dimensions to exclude. + By default None + + Returns + ------- + torch.Tensor + Weighted loss + """ + out = pred - target + + out = self.scale(out, scalar_indices, without_scalars=without_scalars) + + return self.scale_by_node_weights(out, squash) + + @property + def name(self) -> str: + """Used for logging identification purposes.""" + return self.__class__.__name__.lower() + + +class FunctionalWeightedLoss(BaseWeightedLoss): + """WeightedLoss which a user can subclass and provide `calculate_difference`. + + `calculate_difference` should calculate the difference between the prediction and target. + All scaling and weighting is handled by the parent class. + + Example: + -------- + ```python + class MyLoss(FunctionalWeightedLoss): + def calculate_difference(self, pred, target): + return pred - target + ``` + """ + + def __init__( + self, + node_weights: torch.Tensor, + ignore_nans: bool = False, + ) -> None: + super().__init__(node_weights, ignore_nans) + + @abstractmethod + def calculate_difference(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """Calculate Difference between prediction and target.""" + + def forward( + self, + pred: torch.Tensor, + target: torch.Tensor, + squash: bool = True, + *, + scalar_indices: tuple[int, ...] | None = None, + without_scalars: list[str] | list[int] | None = None, + ) -> torch.Tensor: + """Calculates the lat-weighted scaled loss. + + Parameters + ---------- + pred : torch.Tensor + Prediction tensor, shape (bs, ensemble, lat*lon, n_outputs) + target : torch.Tensor + Target tensor, shape (bs, ensemble, lat*lon, n_outputs) + squash : bool, optional + Average last dimension, by default True + scalar_indices: tuple[int,...], optional + Indices to subset the calculated scalar with, by default None + without_scalars: list[str] | list[int] | None, optional + list of scalars to exclude from scaling. Can be list of names or dimensions to exclude. 
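A runnable version of the `FunctionalWeightedLoss` pattern shown in the class docstring above: subclass, implement `calculate_difference`, and the base class handles scalar and node weighting.

```python
import torch

from anemoi.training.losses.weightedloss import FunctionalWeightedLoss

class WeightedL1(FunctionalWeightedLoss):
    def calculate_difference(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        return (pred - target).abs()

loss_fn = WeightedL1(node_weights=torch.rand(10))  # one weight per grid node
pred = torch.randn(2, 1, 10, 5)    # (bs, ensemble, lat*lon, n_outputs)
target = torch.randn(2, 1, 10, 5)
print(loss_fn(pred, target))       # scalar: node-weighted mean absolute error
```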
+ By default None + + + Returns + ------- + torch.Tensor + Weighted loss + """ + out = self.calculate_difference(pred, target) + + out = self.scale(out, scalar_indices, without_scalars=without_scalars) + return self.scale_by_node_weights(out, squash) diff --git a/src/anemoi/training/train/__init__.py b/src/anemoi/training/train/__init__.py index 282d6a69..c167afa2 100644 --- a/src/anemoi/training/train/__init__.py +++ b/src/anemoi/training/train/__init__.py @@ -1,6 +1,8 @@ -# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# (C) Copyright 2024 Anemoi contributors. +# # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. diff --git a/src/anemoi/training/train/forecaster.py b/src/anemoi/training/train/forecaster.py index 36f6fa9a..80459b8f 100644 --- a/src/anemoi/training/train/forecaster.py +++ b/src/anemoi/training/train/forecaster.py @@ -12,7 +12,10 @@ import math import os from collections import defaultdict +from collections.abc import Generator from collections.abc import Mapping +from typing import Optional +from typing import Union import numpy as np import pytorch_lightning as pl @@ -29,8 +32,8 @@ from torch.utils.checkpoint import checkpoint from torch_geometric.data import HeteroData -from anemoi.training.losses.mse import WeightedMSELoss from anemoi.training.losses.utils import grad_scaler +from anemoi.training.losses.weightedloss import BaseWeightedLoss from anemoi.training.utils.jsonify import map_config_to_primitives from anemoi.training.utils.masks import Boolean1DMask from anemoi.training.utils.masks import NoOutputMask @@ -83,22 +86,35 @@ def __init__( self.save_hyperparameters() self.latlons_data = graph_data[config.graph.data].x - self.loss_weights = graph_data[config.graph.data][config.model.node_loss_weight].squeeze() + self.node_weights = graph_data[config.graph.data][config.model.node_loss_weight].squeeze() if config.model.get("output_mask", None) is not None: self.output_mask = Boolean1DMask(graph_data[config.graph.data][config.model.output_mask]) else: self.output_mask = NoOutputMask() - self.loss_weights = self.output_mask.apply(self.loss_weights, dim=0, fill_value=0.0) + self.node_weights = self.output_mask.apply(self.node_weights, dim=0, fill_value=0.0) self.logger_enabled = config.diagnostics.log.wandb.enabled or config.diagnostics.log.mlflow.enabled - self.metric_ranges, self.metric_ranges_validation, loss_scaling = self.metrics_loss_scaling( - config, - data_indices, - ) - self.loss = WeightedMSELoss(node_weights=self.loss_weights, data_variances=loss_scaling) - self.metrics = WeightedMSELoss(node_weights=self.loss_weights, ignore_nans=True) + variable_scaling = self.get_variable_scaling(config, data_indices) + + _, self.val_metric_ranges = self.get_val_metric_ranges(config, data_indices) + + # Kwargs to pass to the loss function + loss_kwargs = {"node_weights": self.node_weights} + # Scalars to include in the loss function, must be of form (dim, scalar) + scalars = {"variable": (-1, variable_scaling)} + + self.loss = self.get_loss_function(config.training.training_loss, scalars=scalars, **loss_kwargs) + + assert isinstance(self.loss, torch.nn.Module) and not isinstance( + self.loss, + torch.nn.ModuleList, + ), f"Loss function must be a 
`torch.nn.Module`, not a {type(self.loss).__name__!r}" + + self.metrics = self.get_loss_function(config.training.validation_metrics, scalars=scalars, **loss_kwargs) + if not isinstance(self.metrics, torch.nn.ModuleList): + self.metrics = torch.nn.ModuleList([self.metrics]) if config.training.loss_gradient_scaling: self.loss.register_full_backward_hook(grad_scaler, prepend=False) @@ -125,8 +141,6 @@ def __init__( LOGGER.debug("Rollout max : %d", self.rollout_max) LOGGER.debug("Multistep: %d", self.multi_step) - self.enable_plot = config.diagnostics.plot.enabled - self.model_comm_group_id = int(os.environ.get("SLURM_PROCID", "0")) // config.hardware.num_gpus_per_model self.model_comm_group_rank = int(os.environ.get("SLURM_PROCID", "0")) % config.hardware.num_gpus_per_model self.model_comm_num_groups = math.ceil( @@ -136,48 +150,94 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: return self.model(x, self.model_comm_group) + # Future import breaks other type hints TODO Harrison Cook @staticmethod - def metrics_loss_scaling(config: DictConfig, data_indices: IndexCollection) -> tuple[dict, torch.Tensor]: - metric_ranges = defaultdict(list) - metric_ranges_validation = defaultdict(list) - loss_scaling = ( - np.ones((len(data_indices.internal_data.output.full),), dtype=np.float32) - * config.training.loss_scaling.default - ) + def get_loss_function( + config: DictConfig, + scalars: Union[dict[str, tuple[Union[int, tuple[int, ...], torch.Tensor]]], None] = None, # noqa: FA100 + **kwargs, + ) -> Union[torch.nn.Module, torch.nn.ModuleList]: # noqa: FA100 + """Get loss functions from config. - pressure_level = instantiate(config.training.pressure_level_scaler) + Can be ModuleList if multiple losses are specified. - LOGGER.info( - "Pressure level scaling: use scaler %s with slope %.4f and minimum %.2f", - type(pressure_level).__name__, - pressure_level.slope, - pressure_level.minimum, - ) + Parameters + ---------- + config : DictConfig + Loss function configuration, should include `scalars` if scalars are to be added to the loss function. + scalars : Union[dict[str, tuple[Union[int, tuple[int, ...], torch.Tensor]]], None], optional + Scalars which can be added to the loss function. Defaults to None., by default None + If a scalar is to be added to the loss, ensure it is in `scalars` in the loss config + E.g. + If `scalars: ['variable']` is set in the config, and `variable` in `scalars` + `variable` will be added to the scalar of the loss function. 
+        kwargs : Any
+            Additional arguments to pass to the loss function
+
+        Returns
+        -------
+        Union[torch.nn.Module, torch.nn.ModuleList]
+            Loss function, or list of metrics
+
+        Raises
+        ------
+        TypeError
+            If not a subclass of `BaseWeightedLoss`
+        ValueError
+            If scalar is not found in valid scalars
+        """
+        config_container = OmegaConf.to_container(config, resolve=False)
+        if isinstance(config_container, list):
+            return torch.nn.ModuleList(
+                [
+                    GraphForecaster.get_loss_function(
+                        OmegaConf.create(loss_config),
+                        scalars=scalars,
+                        **kwargs,
+                    )
+                    for loss_config in config
+                ],
+            )
+
+        loss_config = OmegaConf.to_container(config, resolve=True)
+        scalars_to_include = loss_config.pop("scalars", [])
+
+        # Instantiate the loss function from the config
+        loss_function = instantiate(loss_config, **kwargs)
+
+        if not isinstance(loss_function, BaseWeightedLoss):
+            error_msg = f"Loss must be a subclass of 'BaseWeightedLoss', not {type(loss_function)}"
+            raise TypeError(error_msg)
+
+        for key in scalars_to_include:
+            if key not in (scalars or {}):
+                error_msg = f"Scalar {key!r} not found in valid scalars: {list((scalars or {}).keys())}"
+                raise ValueError(error_msg)
+            loss_function.add_scalar(*scalars[key], name=key)
+
+        return loss_function
+
+    @staticmethod
+    def get_val_metric_ranges(config: DictConfig, data_indices: IndexCollection) -> tuple[dict, dict]:
+        metric_ranges = defaultdict(list)
+        metric_ranges_validation = defaultdict(list)
         for key, idx in data_indices.internal_model.output.name_to_index.items():
-            # Split pressure levels on "_" separator
             split = key.split("_")
             if len(split) > 1 and split[-1].isdigit():
-                # Create grouped metrics for pressure levels (e.g. Q, T, U, V, etc.) for logger
+                # Group metrics for pressure levels (e.g., Q, T, U, V, etc.)
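The `get_loss_function` resolution above can be exercised directly; a sketch, assuming the module paths introduced in this diff. Hydra's `instantiate` resolves the `_target_` key, and the `scalars` list in the config opts the loss into entries of the `scalars` mapping:

```python
import torch
from omegaconf import OmegaConf

from anemoi.training.train.forecaster import GraphForecaster

config = OmegaConf.create(
    {
        "_target_": "anemoi.training.losses.mse.WeightedMSELoss",
        "scalars": ["variable"],
    },
)
scalars = {"variable": (-1, torch.rand(5))}  # scale the n_outputs dimension
loss = GraphForecaster.get_loss_function(
    config,
    scalars=scalars,
    node_weights=torch.rand(10),
)
assert loss.name == "wmse"
```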
metric_ranges[f"pl_{split[0]}"].append(idx) - # Create pressure levels in loss scaling vector - if split[0] in config.training.loss_scaling.pl: - loss_scaling[idx] = config.training.loss_scaling.pl[split[0]] * pressure_level.scaler( - int(split[-1]), - ) - else: - LOGGER.debug("Parameter %s was not scaled.", key) else: metric_ranges[f"sfc_{key}"].append(idx) - # Create surface variables in loss scaling vector - if key in config.training.loss_scaling.sfc: - loss_scaling[idx] = config.training.loss_scaling.sfc[key] - else: - LOGGER.debug("Parameter %s was not scaled.", key) - # Create specific metrics from hydra to log in logger + + # Specific metrics from hydra to log in logger if key in config.training.metrics: metric_ranges[key] = [idx] - loss_scaling = torch.from_numpy(loss_scaling) + + # Add the full list of output indices + metric_ranges["all"] = data_indices.internal_model.output.full.tolist() + # metric for validation, after postprocessing for key, idx in data_indices.model.output.name_to_index.items(): # Split pressure levels on "_" separator @@ -190,7 +250,47 @@ def metrics_loss_scaling(config: DictConfig, data_indices: IndexCollection) -> t # Create specific metrics from hydra to log in logger if key in config.training.metrics: metric_ranges_validation[key] = [idx] - return metric_ranges, metric_ranges_validation, loss_scaling + + return metric_ranges, metric_ranges_validation + + @staticmethod + def get_variable_scaling( + config: DictConfig, + data_indices: IndexCollection, + ) -> torch.Tensor: + variable_loss_scaling = ( + np.ones((len(data_indices.internal_data.output.full),), dtype=np.float32) + * config.training.variable_loss_scaling.default + ) + pressure_level = instantiate(config.training.pressure_level_scaler) + + LOGGER.info( + "Pressure level scaling: use scaler %s with slope %.4f and minimum %.2f", + type(pressure_level).__name__, + pressure_level.slope, + pressure_level.minimum, + ) + + for key, idx in data_indices.internal_model.output.name_to_index.items(): + split = key.split("_") + if len(split) > 1 and split[-1].isdigit(): + # Apply pressure level scaling + if split[0] in config.training.variable_loss_scaling.pl: + variable_loss_scaling[idx] = config.training.variable_loss_scaling.pl[ + split[0] + ] * pressure_level.scaler( + int(split[-1]), + ) + else: + LOGGER.debug("Parameter %s was not scaled.", key) + else: + # Apply surface variable scaling + if key in config.training.variable_loss_scaling.sfc: + variable_loss_scaling[idx] = config.training.variable_loss_scaling.sfc[key] + else: + LOGGER.debug("Parameter %s was not scaled.", key) + + return torch.from_numpy(variable_loss_scaling) def set_model_comm_group(self, model_comm_group: ProcessGroup) -> None: LOGGER.debug("set_model_comm_group: %s", model_comm_group) @@ -223,17 +323,43 @@ def advance_input( ] return x - def _step( + def rollout_step( self, batch: torch.Tensor, - batch_idx: int, + rollout: Optional[int] = None, # noqa: FA100 + training_mode: bool = True, validation_mode: bool = False, - ) -> tuple[torch.Tensor, Mapping[str, torch.Tensor]]: - del batch_idx - loss = torch.zeros(1, dtype=batch.dtype, device=self.device, requires_grad=False) + ) -> Generator[tuple[Union[torch.Tensor, None], dict, list], None, None]: # noqa: FA100 + """Rollout step for the forecaster. + + Will run pre_processors on batch, but not post_processors on predictions. 
+ + Parameters + ---------- + batch : torch.Tensor + Batch to use for rollout + rollout : Optional[int], optional + Number of times to rollout for, by default None + If None, will use self.rollout + training_mode : bool, optional + Whether in training mode and to calculate the loss, by default True + If False, loss will be None + validation_mode : bool, optional + Whether in validation mode, and to calculate validation metrics, by default False + If False, metrics will be empty + + Yields + ------ + Generator[tuple[Union[torch.Tensor, None], dict, list], None, None] + Loss value, metrics, and predictions (per step) + + Returns + ------- + None + None + """ # for validation not normalized in-place because remappers cannot be applied in-place batch = self.model.pre_processors(batch, in_place=not validation_mode) - metrics = {} # start rollout of preprocessed batch x = batch[ @@ -242,29 +368,52 @@ def _step( ..., self.data_indices.internal_data.input.full, ] # (bs, multi_step, latlon, nvar) + msg = ( + "Batch length not sufficient for requested multi_step length!" + f", {batch.shape[1]} !>= {rollout + self.multi_step}" + ) + assert batch.shape[1] >= rollout + self.multi_step, msg - y_preds = [] - for rollout_step in range(self.rollout): + for rollout_step in range(rollout or self.rollout): # prediction at rollout step rollout_step, shape = (bs, latlon, nvar) y_pred = self(x) y = batch[:, self.multi_step + rollout_step, ..., self.data_indices.internal_data.output.full] # y includes the auxiliary variables, so we must leave those out when computing the loss - loss += checkpoint(self.loss, y_pred, y, use_reentrant=False) + loss = checkpoint(self.loss, y_pred, y, use_reentrant=False) if training_mode else None x = self.advance_input(x, y_pred, batch, rollout_step) + metrics_next = {} if validation_mode: - metrics_next, y_preds_next = self.calculate_val_metrics( + metrics_next = self.calculate_val_metrics( y_pred, y, rollout_step, - enable_plot=self.enable_plot, ) - metrics.update(metrics_next) - y_preds.extend(y_preds_next) + yield loss, metrics_next, y_pred + + def _step( + self, + batch: torch.Tensor, + batch_idx: int, + validation_mode: bool = False, + ) -> tuple[torch.Tensor, Mapping[str, torch.Tensor]]: + del batch_idx + loss = torch.zeros(1, dtype=batch.dtype, device=self.device, requires_grad=False) + metrics = {} + y_preds = [] + + for loss_next, metrics_next, y_preds_next in self.rollout_step( + batch, + rollout=self.rollout, + training_mode=True, + validation_mode=validation_mode, + ): + loss += loss_next + metrics.update(metrics_next) + y_preds.extend(y_preds_next) - # scale loss loss *= 1.0 / self.rollout return loss, metrics, y_preds @@ -273,26 +422,51 @@ def calculate_val_metrics( y_pred: torch.Tensor, y: torch.Tensor, rollout_step: int, - enable_plot: bool = False, - ) -> tuple[dict, list]: + ) -> tuple[dict, list[torch.Tensor]]: + """Calculate metrics on the validation output. + + Parameters + ---------- + y_pred: torch.Tensor + Predicted ensemble + y: torch.Tensor + Ground truth (target). 
+ rollout_step: int + Rollout step + + Returns + ------- + val_metrics, preds: + validation metrics and predictions + """ metrics = {} - y_preds = [] y_postprocessed = self.model.post_processors(y, in_place=False) y_pred_postprocessed = self.model.post_processors(y_pred, in_place=False) - for mkey, indices in self.metric_ranges_validation.items(): - metrics[f"{mkey}_{rollout_step + 1}"] = self.metrics( - y_pred_postprocessed[..., indices], - y_postprocessed[..., indices], - ) - if enable_plot: - y_preds.append(y_pred) - return metrics, y_preds + for metric in self.metrics: + metric_name = getattr(metric, "name", metric.__class__.__name__.lower()) + + if not isinstance(metric, BaseWeightedLoss): + # If not a weighted loss, we cannot feature scale, so call normally + metrics[f"{metric_name}/{rollout_step + 1}"] = metric( + y_pred_postprocessed, + y_postprocessed, + ) + continue + + for mkey, indices in self.val_metric_ranges.items(): + metrics[f"{metric_name}/{mkey}/{rollout_step + 1}"] = metric( + y_pred_postprocessed[..., indices], + y_postprocessed[..., indices], + scalar_indices=[..., indices], + ) + + return metrics def training_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor: train_loss, _, _ = self._step(batch, batch_idx) self.log( - "train_wmse", + f"train_{getattr(self.loss, 'name', self.loss.__class__.__name__.lower())}", train_loss, on_epoch=True, on_step=True, @@ -335,7 +509,7 @@ def validation_step(self, batch: torch.Tensor, batch_idx: int) -> None: with torch.no_grad(): val_loss, metrics, y_preds = self._step(batch, batch_idx, validation_mode=True) self.log( - "val_wmse", + f"val_{getattr(self.loss, 'name', self.loss.__class__.__name__.lower())}", val_loss, on_epoch=True, on_step=True, diff --git a/src/anemoi/training/train/profiler.py b/src/anemoi/training/train/profiler.py new file mode 100644 index 00000000..40dd50e7 --- /dev/null +++ b/src/anemoi/training/train/profiler.py @@ -0,0 +1,351 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
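For reference, the validation metric keys assembled above are namespaced as `{metric}/{variable group}/{rollout step}`, while the headline losses keep the `train_`/`val_` prefix; e.g.:

```python
metric_name, mkey, rollout_step = "wmse", "pl_t", 0
assert f"{metric_name}/{mkey}/{rollout_step + 1}" == "wmse/pl_t/1"
assert f"train_{metric_name}" == "train_wmse"  # headline training loss key
```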
+ +from __future__ import annotations + +import logging +import os +import warnings +from datetime import datetime +from datetime import timezone +from functools import cached_property +from pathlib import Path +from typing import TYPE_CHECKING + +import hydra +import pandas as pd +from pytorch_lightning.utilities import rank_zero_only +from rich.console import Console + +if TYPE_CHECKING: + from anemoi.training.data.datamodule import AnemoiDatasetsDataModule + from pytorch_lightning.loggers.logger import Logger + from omegaconf import DictConfig + import pytorch_lightning as pl + +from anemoi.training.diagnostics.profilers import BenchmarkProfiler +from anemoi.training.diagnostics.profilers import ProfilerProgressBar +from anemoi.training.train.train import AnemoiTrainer + +LOGGER = logging.getLogger(__name__) +console = Console(record=True, width=200) + + +class AnemoiProfiler(AnemoiTrainer): + """Profiling for Anemoi.""" + + def __init__(self, config: DictConfig) -> None: + super().__init__(config) + + def print_report(self, title: str, dataframe: pd.DataFrame, color: str = "white", emoji: str = "") -> None: + if title == "Model Summary": + console.print(f"[bold {color}]{title}[/bold {color}]", f":{emoji}:") + console.print(dataframe, end="\n\n") + else: + console.print(f"[bold {color}]{title}[/bold {color}]", f":{emoji}:") + console.print(dataframe.to_markdown(headers="keys", tablefmt="psql"), end="\n\n") + + @staticmethod + def print_title() -> None: + console.print("[bold magenta] Benchmark Profiler Summary [/bold magenta]!", ":book:") + + @staticmethod + def print_metadata() -> None: + console.print(f"[bold blue] SLURM NODE(s) {os.environ['HOST']} [/bold blue]!") + console.print(f"[bold blue] SLURM JOB ID {os.environ['SLURM_JOB_ID']} [/bold blue]!") + console.print(f"[bold blue] TIMESTAMP {datetime.now(timezone.utc).strftime('%d/%m/%Y %H:%M:%S')} [/bold blue]!") + + @rank_zero_only + def print_benchmark_profiler_report( + self, + speed_metrics_df: pd.DataFrame | None = None, + time_metrics_df: pd.DataFrame | None = None, + memory_metrics_df: pd.DataFrame | None = None, + system_metrics_df: pd.DataFrame | None = None, + model_summary: str | None = None, + ) -> None: + self.print_title() + self.print_metadata() + + if time_metrics_df is not None: + warnings.warn( + "INFO: Time Report metrics represent single-node metrics (not multi-node aggregated metrics)", + ) + warnings.warn("INFO: Metrics with a * symbol, represent the value after aggregating all steps") + self.print_report("Time Profiling", time_metrics_df, color="green", emoji="alarm_clock") + + if speed_metrics_df is not None: + warnings.warn( + "INFO: Speed Report metrics are single-node metrics (not multi-node aggregated metrics)", + ) + self.print_report("Speed Profiling", speed_metrics_df, color="yellow", emoji="racing_car") + + if memory_metrics_df is not None: + warnings.warn("INFO: Memory Report metrics represent metrics aggregated across all nodes") + self.print_report("Memory Profiling", memory_metrics_df, color="purple", emoji="floppy_disk") + + if system_metrics_df is not None: + self.print_report("System Profiling", system_metrics_df, color="Red", emoji="desktop_computer") + + if model_summary is not None: + self.print_report("Model Summary", model_summary, color="Orange", emoji="robot") + + @staticmethod + def write_benchmark_profiler_report() -> None: + console.save_html("report.html") + + @staticmethod + def to_df(sample_dict: dict[str, float], precision: str = ".5") -> pd.DataFrame: + df = 
pd.DataFrame(sample_dict.items()) + df.columns = ["metric", "value"] + df.value = df.value.apply(lambda x: f"%{precision}f" % x) + return df + + @cached_property + def speed_profile(self) -> pd.DataFrame | None: + """Speed Profiler Report. + + Get speed metrics from Progress Bar for training and validation. + """ + if self.config.diagnostics.benchmark_profiler.speed.enabled: + # Find the first ProfilerProgressBar callback. + for callback in self.callbacks: + if isinstance(callback, ProfilerProgressBar): + return self.profiler.get_speed_profiler_df(callback) + else: + error_msg = "No ProfilerProgressBar callback found." + raise ValueError(error_msg) + else: + return None + + def _get_logger(self) -> dict[str, Logger] | None: + if (self.config.diagnostics.log.wandb.enabled) and (not self.config.diagnostics.log.wandb.offline): + logger_info = {"logger_name": "wandb", "logger": self.wandb_logger} + elif self.config.diagnostics.log.tensorboard.enabled: + logger_info = {"logger_name": "tensorboard", "logger": self.tensorboard_logger} + elif self.config.diagnostics.log.mlflow.enabled: + logger_info = {"logger_name": "mlflow", "logger": self.mlflow_logger} + else: + LOGGER.warning("No logger enabled for system profiler") + logger_info = None + return logger_info + + @cached_property + def system_profile(self) -> pd.DataFrame | None: + """System Profiler Report.""" + if self.config.diagnostics.benchmark_profiler.system.enabled: + logger_info = self._get_logger() + if logger_info: + return self.profiler.get_system_profiler_df( + logger_name=logger_info["logger_name"], + logger=logger_info["logger"], + ) + LOGGER.warning("System Profiler Report is not available") + return None + return None + + @cached_property + def memory_profile(self) -> pd.DataFrame | None: + """Memory Profiler Report.""" + if self.config.diagnostics.benchmark_profiler.memory.enabled: + return self.profiler.get_memory_profiler_df() + return None + + @cached_property + def time_profile(self) -> pd.DataFrame | None: + """Time Profiler Report.""" + if self.config.diagnostics.benchmark_profiler.time.enabled: + return self.profiler.get_time_profiler_df() + return None + + @cached_property + def model_summary(self) -> str | None: + if self.config.diagnostics.benchmark_profiler.model_summary.enabled: + if self.config.hardware.num_gpus_per_model > 1: + LOGGER.warning("Model Summary is not supported when using model sharding") + self.config.diagnostics.benchmark_profiler.model_summary.enabled = False + return None + model = self.model + example_input_array = self.example_input_array + return self.profiler.get_model_summary(model=model, example_input_array=example_input_array) + return None + + @rank_zero_only + def export_to_logger(self) -> None: + if (self.config.diagnostics.log.wandb.enabled) and (not self.config.diagnostics.log.wandb.offline): + self.to_wandb() + + elif self.config.diagnostics.log.mlflow.enabled: + self.to_mlflow() + + @rank_zero_only + def report(self) -> None: + """Print report to console.""" + LOGGER.info("Generating Profiler reports") + self.print_benchmark_profiler_report( + memory_metrics_df=self.memory_profile, + time_metrics_df=self.time_profile, + speed_metrics_df=self.speed_profile, # speed profile needs to be generated after time and memory reports + system_metrics_df=self.system_profile, + model_summary=self.model_summary, + ) + + def _get_extra_files(self) -> list[Path]: + extra_files = [] + extra_files.extend(self.profiler.dirpath.glob("*.pickle")) + # These trace files are too big to push to MLFlow so + # we won't push them as artifacts
extra_files.extend(self.profiler.dirpath.glob("*.json")) + return extra_files + + def _log_reports_to_mlflow(self, run_id: str, data: pd.DataFrame, artifact_file: str, report_fname: str) -> None: + self.mlflow_logger.experiment.log_table( + run_id=run_id, + data=data, + artifact_file=artifact_file, + ) + + self.mlflow_logger.experiment.log_artifact(run_id, report_fname) + + @rank_zero_only + def to_mlflow(self) -> None: + """Log report into MLFlow.""" + LOGGER.info("Logging Profiler report to MLflow") + self.write_benchmark_profiler_report() + # check this https://stackoverflow.com/questions/71151054/how-to-log-a-table-of-metrics-into-mlflow + + run_id = self.mlflow_logger.run_id + if self.config.diagnostics.benchmark_profiler.system.enabled: + self._log_reports_to_mlflow( + run_id=run_id, + data=self.system_profile, + artifact_file="system_metrics_report.json", + report_fname=self.profiler.system_report_fname, + ) + + if self.config.diagnostics.benchmark_profiler.time.enabled: + self._log_reports_to_mlflow( + run_id=run_id, + data=self.time_profile, + artifact_file="time_metrics_report.json", + report_fname=self.profiler.time_report_fname, + ) + + if self.config.diagnostics.benchmark_profiler.speed.enabled: + self._log_reports_to_mlflow( + run_id=run_id, + data=self.speed_profile, + artifact_file="speed_metrics_report.json", + report_fname=self.profiler.speed_report_fname, + ) + + if self.config.diagnostics.benchmark_profiler.memory.enabled: + self._log_reports_to_mlflow( + run_id=run_id, + data=self.memory_profile, + artifact_file="memory_metrics_report.json", + report_fname=self.profiler.memory_report_fname, + ) + + extra_files = self._get_extra_files() + for file in extra_files: + artifact_path = self.profiler.dirpath / file + if artifact_path.is_file(): + self.mlflow_logger.experiment.log_artifact(run_id, artifact_path) + + if self.config.diagnostics.benchmark_profiler.model_summary.enabled: + self.mlflow_logger.experiment.log_artifact(run_id, self.profiler.model_summary_fname) + + @rank_zero_only + def to_wandb(self) -> None: + """Log report into W&B.""" + LOGGER.info("Logging Profiler report to W&B") + self.write_benchmark_profiler_report() + import wandb + from pytorch_lightning.loggers.wandb import WandbLogger + + logger = WandbLogger( + project=self.run_dict["project"], + entity=self.run_dict["entity"], + id=self.run_dict["id"], + offline=self.config.diagnostics.log.wandb.offline, + resume=self.run_dict["id"], + ) + + logger.experiment.log({"speed_metrics_report": wandb.Table(dataframe=self.speed_profile)}) + logger.experiment.log({"system_metrics_report": wandb.Table(dataframe=self.system_profile)}) + logger.experiment.log({"time_metrics_report": wandb.Table(dataframe=self.time_profile)}) + logger.experiment.log({"memory_metrics_report": wandb.Table(dataframe=self.memory_profile)}) + logger.experiment.log({"model_summary_report": wandb.Table(dataframe=self.model_summary)}) + with Path("report.html").open("r") as f: + logger.experiment.log({"reports_benchmark_profiler": wandb.Html(f)}) + logger.experiment.finish() + + @cached_property + def callbacks(self) -> list[pl.callbacks.Callback]: + callbacks = super().callbacks + callbacks.append(ProfilerProgressBar()) + if self.config.diagnostics.benchmark_profiler.snapshot.enabled: + from anemoi.training.diagnostics.callbacks.profiler import MemorySnapshotRecorder + from anemoi.training.diagnostics.profilers import check_torch_version + + available = check_torch_version() + + if available: # if torch is below 2.1.0, the callback
will not be added + callbacks.append(MemorySnapshotRecorder(self.config)) + return callbacks + + @cached_property + def datamodule(self) -> AnemoiDatasetsDataModule: + datamodule = super().datamodule + # to generate a model summary with shapes we need a sample input array + batch = next(iter(datamodule.train_dataloader())) + self.example_input_array = batch[ + :, + 0 : self.config.training.multistep_input, + ..., + self.data_indices.data.input.full, + ] + return datamodule + + @cached_property + def profiler(self) -> BenchmarkProfiler: + return BenchmarkProfiler(self.config) + + def _update_paths(self) -> None: + """Update the paths in the configuration.""" + super()._update_paths() + + if self.run_id: # when using mlflow only rank0 will have a run_id except when resuming runs + # Multi-gpu new runs or forked runs - only rank 0 + # Multi-gpu resumed runs - all ranks + self.config.hardware.paths.profiler = Path(self.config.hardware.paths.profiler, self.run_id) + elif self.config.training.fork_run_id: + parent_run = self.config.training.fork_run_id + self.config.hardware.paths.profiler = Path(self.config.hardware.paths.profiler, parent_run) + LOGGER.info("Profiler path: %s", self.config.hardware.paths.profiler) + + def _close_logger(self) -> None: + if (self.config.diagnostics.log.wandb.enabled) and (not self.config.diagnostics.log.wandb.offline): + # We need to close the W&B logger to be able to read the System Metrics + self.wandb_logger.experiment.finish() + + def profile(self) -> None: + """Profile the model.""" + self.train() + self.report() + self.export_to_logger() + + +@hydra.main(version_base=None, config_path="../config", config_name="config") +def main(config: DictConfig) -> None: + AnemoiProfiler(config).profile() + + +if __name__ == "__main__": + main() diff --git a/src/anemoi/training/train/train.py b/src/anemoi/training/train/train.py index b772eb2a..553114f5 100644 --- a/src/anemoi/training/train/train.py +++ b/src/anemoi/training/train/train.py @@ -302,6 +302,16 @@ def _log_information(self) -> None: LOGGER.debug("Effective learning rate: %.3e", total_number_of_model_instances * self.config.training.lr.rate) LOGGER.debug("Rollout window length: %d", self.config.training.rollout.start) + if self.config.training.max_epochs is not None and self.config.training.max_steps not in (None, -1): + LOGGER.info( + "Training limits: max_epochs=%d, max_steps=%d. " + "Training will stop at whichever limit is reached first. 
" + "Learning rate scheduler will run for %d steps.", + self.config.training.max_epochs, + self.config.training.max_steps, + self.config.training.lr.iterations, + ) + def _get_server2server_lineage(self) -> None: """Get the server2server lineage.""" self.parent_run_server2server = None @@ -350,12 +360,13 @@ def train(self) -> None: num_nodes=self.config.hardware.num_nodes, precision=self.config.training.precision, max_epochs=self.config.training.max_epochs, + max_steps=self.config.training.max_steps or -1, logger=self.loggers, log_every_n_steps=self.config.diagnostics.log.interval, # run a fixed no of batches per epoch (helpful when debugging) limit_train_batches=self.config.dataloader.limit_batches.training, limit_val_batches=self.config.dataloader.limit_batches.validation, - num_sanity_val_steps=4, + num_sanity_val_steps=self.config.training.num_sanity_val_steps, accumulate_grad_batches=self.config.training.accum_grad_batches, gradient_clip_val=self.config.training.gradient_clip.val, gradient_clip_algorithm=self.config.training.gradient_clip.algorithm, diff --git a/src/anemoi/training/utils/__init__.py b/src/anemoi/training/utils/__init__.py index 33d7fa0a..c167afa2 100644 --- a/src/anemoi/training/utils/__init__.py +++ b/src/anemoi/training/utils/__init__.py @@ -1,8 +1,8 @@ -# (C) Copyright 2024 ECMWF. +# (C) Copyright 2024 Anemoi contributors. # # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. -# diff --git a/src/hydra_plugins/anemoi_searchpath/__init__.py b/src/hydra_plugins/anemoi_searchpath/__init__.py index 33d7fa0a..c167afa2 100644 --- a/src/hydra_plugins/anemoi_searchpath/__init__.py +++ b/src/hydra_plugins/anemoi_searchpath/__init__.py @@ -1,8 +1,8 @@ -# (C) Copyright 2024 ECMWF. +# (C) Copyright 2024 Anemoi contributors. # # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. -# diff --git a/tests/diagnostics/mlflow/test_expansion.py b/tests/diagnostics/mlflow/test_expansion.py new file mode 100644 index 00000000..67dcc54e --- /dev/null +++ b/tests/diagnostics/mlflow/test_expansion.py @@ -0,0 +1,46 @@ +# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+ +from anemoi.training.diagnostics.mlflow.utils import expand_iterables + + +def test_expand_iterables_single_iterable() -> None: + # Test case with a single iterable + dictionary = {"a": ["a", "b", "c"]} + expanded = expand_iterables(dictionary) + assert expanded == {"a.0": "a", "a.1": "b", "a.2": "c", "a.all": ["a", "b", "c"], "a.length": 3} + + +def test_expand_iterables_size_threshold() -> None: + # Below the size threshold, the iterable is not expanded + dictionary = {"a": ["a", "b", "c"]} + expanded = expand_iterables(dictionary, size_threshold=100) + assert expanded == dictionary + + +def test_expand_iterables_with_nested_dict() -> None: + dictionary = {"a": {"b": ["a", "b", "c"]}} + expanded = expand_iterables(dictionary) + assert expanded == {"a": {"b.0": "a", "b.1": "b", "b.2": "c", "b.all": ["a", "b", "c"], "b.length": 3}} + + +def test_expand_iterables_with_nested_dict_thresholded() -> None: + dictionary = {"a": {"b": ["a", "b", "c"]}, "c": ["d"]} + expanded = expand_iterables(dictionary, size_threshold=5) + assert expanded == {"a": {"b.0": "a", "b.1": "b", "b.2": "c", "b.all": ["a", "b", "c"], "b.length": 3}, "c": ["d"]} + + +def test_expand_iterables_with_nested_list() -> None: + dictionary = {"a": [[0, 1, 2], "b", "c"]} + expanded = expand_iterables(dictionary) + assert expanded == { + "a.0": {0: 0, 1: 1, 2: 2}, + "a.1": "b", + "a.2": "c", + "a.all": [[0, 1, 2], "b", "c"], + "a.length": 3, + } diff --git a/tests/diagnostics/test_callbacks.py b/tests/diagnostics/test_callbacks.py new file mode 100644 index 00000000..a61b19f1 --- /dev/null +++ b/tests/diagnostics/test_callbacks.py @@ -0,0 +1,76 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +# ruff: noqa: ANN001, ANN201 + +import omegaconf +import yaml + +from anemoi.training.diagnostics.callbacks import get_callbacks + +default_config = """ +diagnostics: + callbacks: [] + + plot: + enabled: False + callbacks: [] + + debug: + # this will detect and trace back NaNs / Infs etc.
but will slow down training + anomaly_detection: False + + profiler: False + + enable_checkpointing: False + checkpoint: + + log: {} +""" + + +def test_no_extra_callbacks_set(): + # No extra callbacks set + config = omegaconf.OmegaConf.create(yaml.safe_load(default_config)) + callbacks = get_callbacks(config) + assert len(callbacks) == 1 # ParentUUIDCallback + + +def test_add_config_enabled_callback(): + # Add logging callback + config = omegaconf.OmegaConf.create(default_config) + config.diagnostics.callbacks.append({"log": {"mlflow": {"enabled": True}}}) + callbacks = get_callbacks(config) + assert len(callbacks) == 2 + + +def test_add_callback(): + config = omegaconf.OmegaConf.create(default_config) + config.diagnostics.callbacks.append( + {"_target_": "anemoi.training.diagnostics.callbacks.provenance.ParentUUIDCallback"}, + ) + callbacks = get_callbacks(config) + assert len(callbacks) == 2 + + +def test_add_plotting_callback(monkeypatch): + # Add plotting callback + import anemoi.training.diagnostics.callbacks.plot as plot + + class PlotLoss: + def __init__(self, config: omegaconf.DictConfig): + pass + + monkeypatch.setattr(plot, "PlotLoss", PlotLoss) + + config = omegaconf.OmegaConf.create(default_config) + config.diagnostics.plot.enabled = True + config.diagnostics.plot.callbacks = [{"_target_": "anemoi.training.diagnostics.callbacks.plot.PlotLoss"}] + callbacks = get_callbacks(config) + assert len(callbacks) == 2 diff --git a/tests/train/test_loss_function.py b/tests/train/test_loss_function.py new file mode 100644 index 00000000..73d7f246 --- /dev/null +++ b/tests/train/test_loss_function.py @@ -0,0 +1,64 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
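The loss-function tests that follow drive `GraphForecaster.get_loss_function`, which builds the loss from a `DictConfig`. They rely on hydra-style instantiation: the `_target_` key names the class to construct and the remaining keys become constructor arguments. A self-contained illustration of that mechanic, using `torch.nn.MSELoss` as a stand-in target since the real anemoi losses also need node weights and scalars:

```python
import torch
from hydra.utils import instantiate
from omegaconf import DictConfig

# Hypothetical config mirroring the shape used in the tests below.
cfg = DictConfig({"_target_": "torch.nn.MSELoss", "reduction": "sum"})

# hydra resolves `_target_` to the class and forwards the other keys.
loss = instantiate(cfg)
assert isinstance(loss, torch.nn.MSELoss)

pred, target = torch.zeros(3), torch.ones(3)
print(loss(pred, target))  # tensor(3.) — sum of squared errors
```

In the tests, `node_weights` and `scalars` are passed as extra keyword arguments alongside the config, and only scalars named in the config's `scalars` list end up attached to the loss.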
+ + +import torch +from omegaconf import DictConfig + +from anemoi.training.losses.mse import WeightedMSELoss +from anemoi.training.losses.weightedloss import BaseWeightedLoss +from anemoi.training.train.forecaster import GraphForecaster + + +def test_manual_init() -> None: + loss = WeightedMSELoss(torch.ones(1)) + assert loss.node_weights == torch.ones(1) + + +def test_dynamic_init_include() -> None: + loss = GraphForecaster.get_loss_function( + DictConfig({"_target_": "anemoi.training.losses.mse.WeightedMSELoss"}), + node_weights=torch.ones(1), + ) + assert isinstance(loss, BaseWeightedLoss) + assert loss.node_weights == torch.ones(1) + + +def test_dynamic_init_scalar() -> None: + loss = GraphForecaster.get_loss_function( + DictConfig( + { + "_target_": "anemoi.training.losses.mse.WeightedMSELoss", + "scalars": ["test"], + }, + ), + node_weights=torch.ones(1), + scalars={"test": ((0, 1), torch.ones((1, 2)))}, + ) + assert isinstance(loss, BaseWeightedLoss) + + torch.testing.assert_close(loss.node_weights, torch.ones(1)) + assert "test" in loss.scalar + torch.testing.assert_close(loss.scalar.get_scalar(2), torch.ones((1, 2))) + + +def test_dynamic_init_scalar_not_add() -> None: + loss = GraphForecaster.get_loss_function( + DictConfig( + { + "_target_": "anemoi.training.losses.mse.WeightedMSELoss", + "scalars": [], + }, + ), + node_weights=torch.ones(1), + scalars={"test": (-1, torch.ones(2))}, + ) + assert isinstance(loss, BaseWeightedLoss) + torch.testing.assert_close(loss.node_weights, torch.ones(1)) + assert "test" not in loss.scalar diff --git a/tests/train/test_loss_scaling.py b/tests/train/test_loss_scaling.py index 2da5ae00..8dd3772a 100644 --- a/tests/train/test_loss_scaling.py +++ b/tests/train/test_loss_scaling.py @@ -29,7 +29,7 @@ def fake_data(request: SubRequest) -> tuple[DictConfig, IndexCollection]: }, }, "training": { - "loss_scaling": { + "variable_loss_scaling": { "default": 1, "sfc": { "z": 0.1, @@ -128,19 +128,24 @@ def fake_data(request: SubRequest) -> tuple[DictConfig, IndexCollection]: ], indirect=["fake_data"], ) -def test_loss_scaling_vals(fake_data: tuple[DictConfig, IndexCollection], expected_scaling: torch.Tensor) -> None: +def test_variable_loss_scaling_vals( + fake_data: tuple[DictConfig, IndexCollection], + expected_scaling: torch.Tensor, +) -> None: config, data_indices = fake_data - _, _, loss_scaling = GraphForecaster.metrics_loss_scaling(config, data_indices) + variable_loss_scaling = GraphForecaster.get_variable_scaling(config, data_indices) - assert torch.allclose(loss_scaling, expected_scaling) + assert torch.allclose(variable_loss_scaling, expected_scaling) @pytest.mark.parametrize("fake_data", [linear_scaler], indirect=["fake_data"]) def test_metric_range(fake_data: tuple[DictConfig, IndexCollection]) -> None: config, data_indices = fake_data - metric_range, metric_ranges_validation, _ = GraphForecaster.metrics_loss_scaling(config, data_indices) + metric_range, metric_ranges_validation = GraphForecaster.get_val_metric_ranges(config, data_indices) + + del metric_range["all"] expected_metric_range_validation = { "pl_y": [ @@ -161,5 +166,5 @@ def test_metric_range(fake_data: tuple[DictConfig, IndexCollection]) -> None: expected_metric_range["sfc_cos_d"] = [data_indices.internal_model.output.name_to_index["cos_d"]] expected_metric_range["sfc_sin_d"] = [data_indices.internal_model.output.name_to_index["sin_d"]] - assert dict(metric_ranges_validation) == expected_metric_range_validation - assert dict(metric_range) == expected_metric_range + assert metric_ranges_validation == expected_metric_range_validation + assert metric_range == expected_metric_range diff --git a/tests/train/test_scalar.py b/tests/train/test_scalar.py new file mode 100644 index 00000000..9a37e353 --- /dev/null +++ b/tests/train/test_scalar.py @@ -0,0 +1,200 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import pytest +import torch + +from anemoi.training.losses.utils import ScaleTensor + + +def test_scale_contains() -> None: + scale = ScaleTensor(test=(0, 2)) + assert "test" in scale + + +def test_scale_contains_indexing() -> None: + scale = ScaleTensor(test=(0, 2)) + assert 0 in scale + + +def test_scale_tuple_contains_indexing() -> None: + scale = ScaleTensor(test=((0, 1), 2)) + assert (0, 1) in scale + + +def test_scale_tuple_not_contains_indexing() -> None: + scale = ScaleTensor(test=(0, 2)) + assert (0, 1) not in scale + + +def test_scale_contains_subset_indexing() -> None: + scale = ScaleTensor(test=(0, 2), wow=(0, 2)) + assert "test" in scale + scale = scale.subset("wow") + assert "wow" in scale + assert "test" not in scale + + +def test_scale_contains_subset_by_dim_indexing() -> None: + scale = ScaleTensor(test=(0, 2), wow=(1, 2)) + assert "test" in scale + scale = scale.subset_by_dim(1) + assert "wow" in scale + assert "test" not in scale + + +def test_add_existing_scalar() -> None: + scale = ScaleTensor(test=(0, torch.tensor([2.0]))) + with pytest.raises(ValueError, match=r".*already exists.*"): + scale.add_scalar(0, torch.tensor(3.0), name="test") + + +def test_update_scalar() -> None: + scale = ScaleTensor(test=(0, torch.ones(2))) + scale.update_scalar("test", torch.tensor([3.0])) + torch.testing.assert_close(scale.tensors["test"][1], torch.tensor([3.0])) + + +def test_update_missing_scalar() -> None: + scale = ScaleTensor(test=(0, torch.ones(2))) + with pytest.raises(ValueError, match=r".*not found in scalars.*"): + scale.update_scalar("test_missing", torch.tensor([3.0])) + assert "test" in scale + assert (0,) in scale + + +def test_update_scalar_wrong_dim() -> None: + scale = ScaleTensor(test=(0, torch.ones((2, 3)))) + with pytest.raises(ValueError, match=r".*does not match shape of saved scalar.*"): + scale.update_scalar("test", torch.ones((2, 2))) + assert "test" in scale + assert 0 in scale + + +def test_update_scalar_wrong_dim_allow_override() -> None: + scale = ScaleTensor(test=(0, torch.ones((2, 3)))) + assert scale.update_scalar("test", torch.ones((2, 2)), override=True) is None + + +@pytest.mark.parametrize( + ("scalars", "input_tensor", "output"), + [ + ([[0, torch.Tensor([2])]], torch.tensor([1.0, 2.0, 3.0]), torch.tensor([2.0, 4.0, 6.0])), + ([[0, torch.Tensor([0.5])]], torch.tensor([10.0, 20.0, 30.0]), torch.tensor([5.0, 10.0, 15.0])), + ([[-1, torch.Tensor([0.5])]], torch.tensor([10.0, 20.0, 30.0]), torch.tensor([5.0, 10.0, 15.0])), + ([[0, torch.Tensor([0])]], torch.tensor([5.0, 10.0, 15.0]), torch.tensor([0.0, 0.0, 0.0])), + ( + [[0, torch.Tensor([0.5])], [-1, torch.Tensor([3])]], + torch.tensor([10.0, 20.0, 30.0]), + torch.tensor([15.0, 30.0, 45.0]), + ), + ( + [[0, torch.Tensor([0.5])], [0, torch.Tensor([3])]], + torch.tensor([10.0, 20.0, 30.0]), + torch.tensor([15.0, 30.0, 45.0]), + ), + ], +)
45.0]), + ), + ], +) +def test_scale_tensor_one_dim( + scalars: list[list[int, torch.Tensor]], + input_tensor: torch.Tensor, + output: torch.Tensor, +) -> None: + + scale = ScaleTensor() + for scalar in scalars: + scale.add_scalar(*scalar) + + torch.testing.assert_close(scale.scale(input_tensor), output) + + +def test_invalid_dim_sizes() -> None: + scalar = ScaleTensor() + scalar.add_scalar(0, torch.ones((5,))) + + with pytest.raises(ValueError, match=r"Validating tensor 'invalid' raised an error."): + scalar.add_scalar(0, torch.ones((15,)), name="invalid") + + +def test_invalid_dim_sizes_negative_indexing() -> None: + scalar = ScaleTensor() + scalar.add_scalar(0, torch.ones((5,))) + scalar.add_scalar(-1, torch.ones((15,)), name="invalid") + + with pytest.raises(ValueError, match=r"Validating tensor 'invalid' raised an error."): + scalar.resolve(1) + + +def test_valid_dim_sizes_negative_indexing() -> None: + scalar = ScaleTensor() + scalar.add_scalar(0, torch.ones((5,))) + scalar.add_scalar(-1, torch.ones((15,))) + + scalar.resolve(2) + + +@pytest.mark.parametrize( + ("scalars", "input_tensor", "output"), + [ + ([[0, 2.0]], torch.ones([4, 6]), torch.ones([4, 6]) * 2), + ([[0, [[1.0, 1.0], [1.0, 2.0]]]], torch.ones((2, 2)), [[1, 1], [1, 2]]), + ([[(0, 1), [[1.0, 1.0], [1.0, 2.0]]]], torch.ones((2, 2)), [[1, 1], [1, 2]]), + ([[(1, 0), [[1.0, 1.0], [1.0, 2.0]]]], torch.ones((2, 2)), [[1, 1], [1, 2]]), + ([[(0, 1), [[1.0, 2.0], [1.0, 1.0]]]], torch.ones((2, 2)), [[1, 2], [1, 1]]), + ([[(1, 0), [[1.0, 2.0], [1.0, 1.0]]]], torch.ones((2, 2)), [[1, 1], [2, 1]]), + ], +) +def test_scale_tensor_two_dim( + scalars: list[list[int, torch.Tensor]], + input_tensor: torch.Tensor, + output: torch.Tensor, +) -> None: + + scale = ScaleTensor() + for scalar in scalars: + scale.add_scalar(*scalar) + + if not isinstance(input_tensor, torch.Tensor): + input_tensor = torch.tensor(input_tensor, dtype=torch.float32) + if not isinstance(output, torch.Tensor): + output = torch.tensor(output, dtype=torch.float32) + + torch.testing.assert_close(scale.scale(input_tensor), output) + + +def test_scalar_subset() -> None: + scale = ScaleTensor(test=(0, torch.tensor([2.0])), wow=(0, torch.tensor([3.0]))) + subset = scale.subset("test") + assert "test" in subset + assert "wow" not in subset + assert 0 in subset + + +def test_scalar_subset_without() -> None: + scale = ScaleTensor(test=(0, torch.tensor([2.0])), wow=(0, torch.tensor([3.0]))) + subset = scale.without("test") + assert "test" not in subset + assert "wow" in subset + assert 0 in subset + + +def test_scalar_subset_by_dim() -> None: + scale = ScaleTensor(test=(0, torch.tensor([2.0])), wow=(1, torch.tensor([3.0]))) + subset = scale.subset_by_dim(0) + assert "test" in subset + assert "wow" not in subset + assert 0 in subset + + +def test_scalar_subset_by_dim_without() -> None: + scale = ScaleTensor(test=(0, torch.tensor([2.0])), wow=(1, torch.tensor([3.0]))) + subset = scale.without_by_dim(0) + assert "test" not in subset + assert "wow" in subset + assert 0 not in subset
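Taken together, these tests sketch the `ScaleTensor` contract: named scalars are attached to tensor dimensions, composed multiplicatively when applied, and filterable by name or by dimension. A short usage example consistent with the assertions above (the scalar names here are illustrative only):

```python
import torch

from anemoi.training.losses.utils import ScaleTensor

# Two scalars on dim 0 compose multiplicatively when applied.
scale = ScaleTensor(halve=(0, torch.tensor([0.5])))
scale.add_scalar(0, torch.tensor([3.0]), name="triple")

x = torch.tensor([10.0, 20.0, 30.0])
print(scale.scale(x))  # tensor([15., 30., 45.]), i.e. 0.5 * 3 * x

# Subsetting keeps only the named scalar; the other is dropped.
only_halve = scale.subset("halve")
assert "triple" not in only_halve
assert "halve" in only_halve
```

This is the same composition the parametrized `test_scale_tensor_one_dim` cases exercise, and the filtering behaviour matches the `subset`/`without` tests.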