diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml index db316803..1db47a1b 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -15,7 +15,12 @@ jobs: skip-hooks: "no-commit-to-branch" checks: + strategy: + matrix: + python-version: ["3.9", "3.10", "3.11", "3.12"] uses: ecmwf-actions/reusable-workflows/.github/workflows/qa-pytest-pyproject.yml@v2 + with: + python-version: ${{ matrix.python-version }} deploy: needs: [checks, quality] diff --git a/.github/workflows/python-pull-request.yml b/.github/workflows/python-pull-request.yml index 3488f55c..cef24795 100644 --- a/.github/workflows/python-pull-request.yml +++ b/.github/workflows/python-pull-request.yml @@ -16,4 +16,9 @@ jobs: skip-hooks: "no-commit-to-branch" checks: + strategy: + matrix: + python-version: ["3.9", "3.10", "3.11", "3.12"] uses: ecmwf-actions/reusable-workflows/.github/workflows/qa-pytest-pyproject.yml@v2 + with: + python-version: ${{ matrix.python-version }} diff --git a/CHANGELOG.md b/CHANGELOG.md index 231a0fa6..9e14700f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,8 +11,18 @@ Keep it human-readable, your future self will thank you! ## [Unreleased](https://github.com/ecmwf/anemoi-training/compare/0.2.0...HEAD) ### Added +- Mlflow-sync to include new tag for server to server syncing [#83] (https://github.com/ecmwf/anemoi-training/pull/83) +- Mlflow-sync to include functionality to resume and fork server2server runs [#83] (https://github.com/ecmwf/anemoi-training/pull/83) +- Rollout training for Limited Area Models. [#79](https://github.com/ecmwf/anemoi-training/pulls/79) +- Feature: New `Boolean1DMask` class. Enables rollout training for limited area models. [#79](https://github.com/ecmwf/anemoi-training/pulls/79) + ### Fixed +- Mlflow-sync to handle creation of new experiments in the remote server [#83] (https://github.com/ecmwf/anemoi-training/pull/83) +- Fix for multi-gpu when using mlflow due to refactoring of _get_mlflow_run_params function [#99] (https://github.com/ecmwf/anemoi-training/pull/99) +- ci: fix pyshtools install error (#100) https://github.com/ecmwf/anemoi-training/pull/100 + ### Changed +- Update copyright notice ## [0.2.0 - Feature release](https://github.com/ecmwf/anemoi-training/compare/0.1.0...0.2.0) - 2024-10-16 diff --git a/README.md b/README.md index 5ee7892f..e51c7697 100644 --- a/README.md +++ b/README.md @@ -22,7 +22,7 @@ $ pip install anemoi-training ## License ``` -Copyright 2022, European Centre for Medium Range Weather Forecasts. +Copyright 2024, Anemoi contributors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/src/anemoi/training/commands/checkpoint.py b/src/anemoi/training/commands/checkpoint.py index af1aa296..82c95e9e 100644 --- a/src/anemoi/training/commands/checkpoint.py +++ b/src/anemoi/training/commands/checkpoint.py @@ -1,10 +1,13 @@ -# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# (C) Copyright 2024 Anemoi contributors. +# # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. 
+ import argparse import logging diff --git a/src/anemoi/training/commands/config.py b/src/anemoi/training/commands/config.py index e79602d6..221d76dd 100644 --- a/src/anemoi/training/commands/config.py +++ b/src/anemoi/training/commands/config.py @@ -1,10 +1,13 @@ -# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# (C) Copyright 2024 Anemoi contributors. +# # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. + from __future__ import annotations import importlib.resources as pkg_resources diff --git a/src/anemoi/training/commands/mlflow.py b/src/anemoi/training/commands/mlflow.py index 8be8183d..545b8a10 100644 --- a/src/anemoi/training/commands/mlflow.py +++ b/src/anemoi/training/commands/mlflow.py @@ -1,10 +1,13 @@ -# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# (C) Copyright 2024 Anemoi contributors. +# # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. + import argparse from anemoi.training.commands import Command diff --git a/src/anemoi/training/commands/train.py b/src/anemoi/training/commands/train.py index 705b4783..44ce186f 100644 --- a/src/anemoi/training/commands/train.py +++ b/src/anemoi/training/commands/train.py @@ -1,11 +1,13 @@ -# (C) Copyright 2024 ECMWF. +# (C) Copyright 2024 Anemoi contributors. # # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. + from __future__ import annotations import logging diff --git a/src/anemoi/training/data/datamodule.py b/src/anemoi/training/data/datamodule.py index b714c965..f64a3091 100644 --- a/src/anemoi/training/data/datamodule.py +++ b/src/anemoi/training/data/datamodule.py @@ -1,10 +1,13 @@ -# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# (C) Copyright 2024 Anemoi contributors. +# # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. + import logging import os from functools import cached_property diff --git a/src/anemoi/training/data/dataset.py b/src/anemoi/training/data/dataset.py index e2aa12bd..9e368f9c 100644 --- a/src/anemoi/training/data/dataset.py +++ b/src/anemoi/training/data/dataset.py @@ -1,9 +1,12 @@ -# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# (C) Copyright 2024 Anemoi contributors. +# # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 
+# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. + from __future__ import annotations import logging diff --git a/src/anemoi/training/data/scaling.py b/src/anemoi/training/data/scaling.py index 74ba9c23..83419a88 100644 --- a/src/anemoi/training/data/scaling.py +++ b/src/anemoi/training/data/scaling.py @@ -1,10 +1,13 @@ -# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# (C) Copyright 2024 Anemoi contributors. +# # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. + import logging from abc import ABC from abc import abstractmethod diff --git a/src/anemoi/training/diagnostics/callbacks/__init__.py b/src/anemoi/training/diagnostics/callbacks/__init__.py index f2195b5f..cf085eab 100644 --- a/src/anemoi/training/diagnostics/callbacks/__init__.py +++ b/src/anemoi/training/diagnostics/callbacks/__init__.py @@ -139,6 +139,13 @@ def teardown(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: st if self._executor is not None: self._executor.shutdown(wait=True) + def apply_output_mask(self, pl_module: pl.LightningModule, data: torch.Tensor) -> torch.Tensor: + if hasattr(pl_module, "output_mask") and pl_module.output_mask is not None: + # Fill with NaNs values where the mask is False + data[:, :, ~pl_module.output_mask, :] = np.nan + + return data + @abstractmethod @rank_zero_only def _plot( @@ -682,12 +689,16 @@ def _plot( ..., pl_module.data_indices.internal_data.output.full, ].cpu() - data = self.post_processors(input_tensor).numpy() + data = self.post_processors(input_tensor) output_tensor = self.post_processors( torch.cat(tuple(x[self.sample_idx : self.sample_idx + 1, ...].cpu() for x in outputs[1])), in_place=False, - ).numpy() + ) + + output_tensor = pl_module.output_mask.apply(output_tensor, dim=2, fill_value=np.nan).numpy() + data[1:, ...] = pl_module.output_mask.apply(data[1:, ...], dim=2, fill_value=np.nan) + data = data.numpy() for rollout_step in range(pl_module.rollout): fig = plot_predicted_multilevel_flat_sample( @@ -776,11 +787,15 @@ def _plot( ..., pl_module.data_indices.internal_data.output.full, ].cpu() - data = self.post_processors(input_tensor).numpy() + data = self.post_processors(input_tensor) output_tensor = self.post_processors( torch.cat(tuple(x[self.sample_idx : self.sample_idx + 1, ...].cpu() for x in outputs[1])), in_place=False, - ).numpy() + ) + + output_tensor = pl_module.output_mask.apply(output_tensor, dim=2, fill_value=np.nan).numpy() + data[1:, ...] = pl_module.output_mask.apply(data[1:, ...], dim=2, fill_value=np.nan) + data = data.numpy() for rollout_step in range(pl_module.rollout): if self.config.diagnostics.plot.parameters_histogram is not None: diff --git a/src/anemoi/training/diagnostics/logger.py b/src/anemoi/training/diagnostics/logger.py index 599dad36..4e4a35c1 100644 --- a/src/anemoi/training/diagnostics/logger.py +++ b/src/anemoi/training/diagnostics/logger.py @@ -1,9 +1,12 @@ -# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# (C) Copyright 2024 Anemoi contributors. 
+# # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. + from __future__ import annotations import logging @@ -25,7 +28,6 @@ def get_mlflow_logger(config: DictConfig) -> None: return None from anemoi.training.diagnostics.mlflow.logger import AnemoiMLflowLogger - from anemoi.training.diagnostics.mlflow.logger import get_mlflow_run_params resumed = config.training.run_id is not None forked = config.training.fork_run_id is not None @@ -39,7 +41,6 @@ def get_mlflow_logger(config: DictConfig) -> None: tracking_uri = save_dir # create directory if it does not exist Path(config.hardware.paths.logs.mlflow).mkdir(parents=True, exist_ok=True) - run_id, run_name, tags = get_mlflow_run_params(config, tracking_uri) log_hyperparams = True if resumed and not config.diagnostics.log.mlflow.on_resume_create_child: @@ -53,19 +54,22 @@ def get_mlflow_logger(config: DictConfig) -> None: ) log_hyperparams = False + LOGGER.info("AnemoiMLFlow logging to %s", tracking_uri) logger = AnemoiMLflowLogger( experiment_name=config.diagnostics.log.mlflow.experiment_name, + project_name=config.diagnostics.log.mlflow.project_name, tracking_uri=tracking_uri, save_dir=save_dir, - run_name=run_name, - run_id=run_id, + run_name=config.diagnostics.log.mlflow.run_name, + run_id=config.training.run_id, + fork_run_id=config.training.fork_run_id, log_model=config.diagnostics.log.mlflow.log_model, offline=offline, - tags=tags, resumed=resumed, forked=forked, log_hyperparams=log_hyperparams, authentication=config.diagnostics.log.mlflow.authentication, + on_resume_create_child=config.diagnostics.log.mlflow.on_resume_create_child, ) config_params = OmegaConf.to_container(config, resolve=True) diff --git a/src/anemoi/training/diagnostics/maps.py b/src/anemoi/training/diagnostics/maps.py index 265d814b..338a9059 100644 --- a/src/anemoi/training/diagnostics/maps.py +++ b/src/anemoi/training/diagnostics/maps.py @@ -1,10 +1,13 @@ -# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# (C) Copyright 2024 Anemoi contributors. +# # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. + import copy import json import logging diff --git a/src/anemoi/training/diagnostics/mlflow/auth.py b/src/anemoi/training/diagnostics/mlflow/auth.py index 7189ea21..144a967e 100644 --- a/src/anemoi/training/diagnostics/mlflow/auth.py +++ b/src/anemoi/training/diagnostics/mlflow/auth.py @@ -1,10 +1,13 @@ -# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# (C) Copyright 2024 Anemoi contributors. +# # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. 
+ from __future__ import annotations import logging diff --git a/src/anemoi/training/diagnostics/mlflow/client.py b/src/anemoi/training/diagnostics/mlflow/client.py index 5c49f929..2d97cb4b 100644 --- a/src/anemoi/training/diagnostics/mlflow/client.py +++ b/src/anemoi/training/diagnostics/mlflow/client.py @@ -1,10 +1,13 @@ -# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# (C) Copyright 2024 Anemoi contributors. +# # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. + from __future__ import annotations from typing import Any diff --git a/src/anemoi/training/diagnostics/mlflow/logger.py b/src/anemoi/training/diagnostics/mlflow/logger.py index 2abebb1c..183e7a0d 100644 --- a/src/anemoi/training/diagnostics/mlflow/logger.py +++ b/src/anemoi/training/diagnostics/mlflow/logger.py @@ -1,10 +1,13 @@ -# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# (C) Copyright 2024 Anemoi contributors. +# # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. + from __future__ import annotations import io @@ -33,55 +36,11 @@ if TYPE_CHECKING: from argparse import Namespace - from omegaconf import OmegaConf + import mlflow LOGGER = logging.getLogger(__name__) -def get_mlflow_run_params(config: OmegaConf, tracking_uri: str) -> tuple[str | None, str, dict[str, Any]]: - run_id = None - tags = {"projectName": config.diagnostics.log.mlflow.project_name} - # create a tag with the command used to run the script - command = os.environ.get("ANEMOI_TRAINING_CMD", sys.argv[0]) - tags["command"] = command.split("/")[-1] # get the python script name - tags["mlflow.source.name"] = command - if len(sys.argv) > 1: - # add the arguments to the command tag - tags["command"] = tags["command"] + " " + " ".join(sys.argv[1:]) - if config.training.run_id or config.training.fork_run_id: - "Either run_id or fork_run_id must be provided to resume a run." 
- - import mlflow - - if config.diagnostics.log.mlflow.authentication and not config.diagnostics.log.mlflow.offline: - TokenAuth(tracking_uri).authenticate() - - mlflow_client = mlflow.MlflowClient(tracking_uri) - - if config.training.run_id and config.diagnostics.log.mlflow.on_resume_create_child: - parent_run_id = config.training.run_id # parent_run_id - run_name = mlflow_client.get_run(parent_run_id).info.run_name - tags["mlflow.parentRunId"] = parent_run_id - tags["resumedRun"] = "True" # tags can't take boolean values - elif config.training.run_id and not config.diagnostics.log.mlflow.on_resume_create_child: - run_id = config.training.run_id - run_name = mlflow_client.get_run(run_id).info.run_name - mlflow_client.update_run(run_id=run_id, status="RUNNING") - tags["resumedRun"] = "True" - else: - parent_run_id = config.training.fork_run_id - tags["forkedRun"] = "True" - tags["forkedRunId"] = parent_run_id - - if config.diagnostics.log.mlflow.run_name: - run_name = config.diagnostics.log.mlflow.run_name - else: - import uuid - - run_name = f"{uuid.uuid4()!s}" - return run_id, run_name, tags - - class LogsMonitor: """Class for logging terminal output. @@ -90,7 +49,7 @@ class LogsMonitor: Note: If there is an error, the terminal output logging ends before the error message is printed into the log file. In order for the user to see the error message, the user must look at the slurm output file. - We provide the SLRM job id in the very beginning of the log file and print the final status of the run in the end. + We provide the SLURM job id in the very beginning of the log file and print the final status of the run in the end. Parameters ---------- @@ -191,7 +150,7 @@ def start(self) -> None: self._buffer_registry[id(self)] = self._io_buffer # Start thread to asynchronously collect logs self._th_collector.start() - LOGGER.info("Termial Log Path: %s", self.file_save_path) + LOGGER.info("Terminal Log Path: %s", self.file_save_path) if os.getenv("SLURM_JOB_ID"): LOGGER.info("SLURM job id: %s", os.getenv("SLURM_JOB_ID")) @@ -288,18 +247,20 @@ class AnemoiMLflowLogger(MLFlowLogger): def __init__( self, experiment_name: str = "lightning_logs", + project_name: str = "anemoi", run_name: str | None = None, tracking_uri: str | None = os.getenv("MLFLOW_TRACKING_URI"), - tags: dict[str, Any] | None = None, save_dir: str | None = "./mlruns", log_model: Literal[True, False, "all"] = False, prefix: str = "", resumed: bool | None = False, forked: bool | None = False, run_id: str | None = None, + fork_run_id: str | None = None, offline: bool | None = False, authentication: bool | None = None, log_hyperparams: bool | None = True, + on_resume_create_child: bool | None = True, ) -> None: """Initialize the AnemoiMLflowLogger. 
@@ -307,12 +268,12 @@ def __init__( ---------- experiment_name : str, optional Name of experiment, by default "lightning_logs" + project_name : str, optional + Name of the project, by default "anemoi" run_name : str | None, optional Name of run, by default None tracking_uri : str | None, optional Tracking URI of server, by default os.getenv("MLFLOW_TRACKING_URI") - tags : dict[str, Any] | None, optional - Tags to apply, by default None save_dir : str | None, optional Directory to save logs to, by default "./mlruns" log_model : Literal[True, False, "all"], optional @@ -325,31 +286,28 @@ def __init__( Whether the run was forked or not, by default False run_id : str | None, optional Run id of current run, by default None + fork_run_id : str | None, optional + Fork Run id from parent run, by default None offline : bool | None, optional Whether to run offline or not, by default False authentication : bool | None, optional Whether to authenticate with server or not, by default None log_hyperparams : bool | None, optional Whether to log hyperparameters, by default True - + on_resume_create_child: bool | None, optional + Whether to create a child run when resuming a run, by default False """ - if offline: - # OFFLINE - When we run offline we can pass a save_dir pointing to a local path - tracking_uri = None - - else: - # ONLINE - When we pass a tracking_uri to mlflow then it will ignore the - # saving dir and save all artifacts/metrics to the remote server database - save_dir = None - self._resumed = resumed self._forked = forked self._flag_log_hparams = log_hyperparams - if rank_zero_only.rank == 0: - enabled = authentication and not offline - self.auth = TokenAuth(tracking_uri, enabled=enabled) + self._fork_run_server2server = None + self._parent_run_server2server = None + + enabled = authentication and not offline + self.auth = TokenAuth(tracking_uri, enabled=enabled) + if rank_zero_only.rank == 0: if offline: LOGGER.info("MLflow is logging offline.") else: @@ -357,6 +315,24 @@ def __init__( self.auth.authenticate() health_check(tracking_uri) + run_id, run_name, tags = self._get_mlflow_run_params( + project_name=project_name, + run_name=run_name, + config_run_id=run_id, + fork_run_id=fork_run_id, + tracking_uri=tracking_uri, + on_resume_create_child=on_resume_create_child, + ) + # Before creating the run we need to overwrite the tracking_uri and save_dir if offline + if offline: + # OFFLINE - When we run offline we can pass a save_dir pointing to a local path + tracking_uri = None + + else: + # ONLINE - When we pass a tracking_uri to mlflow then it will ignore the + # saving dir and save all artifacts/metrics to the remote server database + save_dir = None + super().__init__( experiment_name=experiment_name, run_name=run_name, @@ -368,6 +344,84 @@ def __init__( run_id=run_id, ) + def _check_server2server_lineage(self, run: mlflow.entities.Run) -> bool: + """Address lineage and metadata for server2server runs. 
+ + Those are runs that have been sync from one remote server to another + """ + server2server = run.data.tags.get("server2server", "False") == "True" + LOGGER.info("Server2Server: %s", server2server) + if server2server: + parent_run_across_servers = run.data.params.get( + "metadata.offline_run_id", + run.data.params.get("metadata.server2server_run_id"), + ) + if self._forked: + # if we want to fork a resume run we need to set the parent_run_across_servers + # but just to restore the checkpoint + self._fork_run_server2server = parent_run_across_servers + else: + self._parent_run_server2server = parent_run_across_servers + + def _get_mlflow_run_params( + self, + project_name: str, + run_name: str, + config_run_id: str, + fork_run_id: str, + tracking_uri: str, + on_resume_create_child: bool, + ) -> tuple[str | None, str, dict[str, Any]]: + + run_id = None + tags = {"projectName": project_name} + + # create a tag with the command used to run the script + command = os.environ.get("ANEMOI_TRAINING_CMD", sys.argv[0]) + tags["command"] = command.split("/")[-1] # get the python script name + tags["mlflow.source.name"] = command + if len(sys.argv) > 1: + # add the arguments to the command tag + tags["command"] = tags["command"] + " " + " ".join(sys.argv[1:]) + + if config_run_id or fork_run_id: + "Either run_id or fork_run_id must be provided to resume a run." + import mlflow + + self.auth.authenticate() + mlflow_client = mlflow.MlflowClient(tracking_uri) + + if config_run_id and on_resume_create_child: + parent_run_id = config_run_id # parent_run_id + parent_run = mlflow_client.get_run(parent_run_id) + run_name = parent_run.info.run_name + self._check_server2server_lineage(parent_run) + tags["mlflow.parentRunId"] = parent_run_id + tags["resumedRun"] = "True" # tags can't take boolean values + elif config_run_id and not on_resume_create_child: + run_id = config_run_id + run = mlflow_client.get_run(run_id) + run_name = run.info.run_name + self._check_server2server_lineage(run) + mlflow_client.update_run(run_id=run_id, status="RUNNING") + tags["resumedRun"] = "True" + else: + parent_run_id = fork_run_id + tags["forkedRun"] = "True" + tags["forkedRunId"] = parent_run_id + run = mlflow_client.get_run(parent_run_id) + self._check_server2server_lineage(run) + + if not run_name: + import uuid + + run_name = f"{uuid.uuid4()!s}" + + if os.getenv("SLURM_JOB_ID"): + tags["SLURM_JOB_ID"] = os.getenv("SLURM_JOB_ID") + + return run_id, run_name, tags + @property def experiment(self) -> MLFlowLogger.experiment: if rank_zero_only.rank == 0: diff --git a/src/anemoi/training/diagnostics/mlflow/utils.py b/src/anemoi/training/diagnostics/mlflow/utils.py index 9e0ae4e7..89f6e002 100644 --- a/src/anemoi/training/diagnostics/mlflow/utils.py +++ b/src/anemoi/training/diagnostics/mlflow/utils.py @@ -1,10 +1,13 @@ -# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# (C) Copyright 2024 Anemoi contributors. +# # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. 
+ import os import requests diff --git a/src/anemoi/training/diagnostics/plots.py b/src/anemoi/training/diagnostics/plots.py index b2004cf4..7b4ba711 100644 --- a/src/anemoi/training/diagnostics/plots.py +++ b/src/anemoi/training/diagnostics/plots.py @@ -1,10 +1,13 @@ -# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# (C) Copyright 2024 Anemoi contributors. +# # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. + from __future__ import annotations import logging diff --git a/src/anemoi/training/distributed/strategy.py b/src/anemoi/training/distributed/strategy.py index c15828ca..c6509795 100644 --- a/src/anemoi/training/distributed/strategy.py +++ b/src/anemoi/training/distributed/strategy.py @@ -1,10 +1,13 @@ -# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# (C) Copyright 2024 Anemoi contributors. +# # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. + import logging import os diff --git a/src/anemoi/training/losses/mse.py b/src/anemoi/training/losses/mse.py index b72a7aea..88ad0d0b 100644 --- a/src/anemoi/training/losses/mse.py +++ b/src/anemoi/training/losses/mse.py @@ -1,11 +1,12 @@ -# (C) Copyright 2024 ECMWF. +# (C) Copyright 2024 Anemoi contributors. # # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. -# + from __future__ import annotations diff --git a/src/anemoi/training/losses/utils.py b/src/anemoi/training/losses/utils.py index 9a866a0a..5ddef3d6 100644 --- a/src/anemoi/training/losses/utils.py +++ b/src/anemoi/training/losses/utils.py @@ -1,11 +1,12 @@ -# (C) Copyright 2024 ECMWF. +# (C) Copyright 2024 Anemoi contributors. # # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. -# + from __future__ import annotations diff --git a/src/anemoi/training/train/forecaster.py b/src/anemoi/training/train/forecaster.py index ff1acfd7..36f6fa9a 100644 --- a/src/anemoi/training/train/forecaster.py +++ b/src/anemoi/training/train/forecaster.py @@ -1,11 +1,12 @@ -# (C) Copyright 2024 ECMWF. +# (C) Copyright 2024 Anemoi contributors. # # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 
+# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. -# + import logging import math @@ -31,6 +32,8 @@ from anemoi.training.losses.mse import WeightedMSELoss from anemoi.training.losses.utils import grad_scaler from anemoi.training.utils.jsonify import map_config_to_primitives +from anemoi.training.utils.masks import Boolean1DMask +from anemoi.training.utils.masks import NoOutputMask LOGGER = logging.getLogger(__name__) @@ -82,6 +85,12 @@ def __init__( self.latlons_data = graph_data[config.graph.data].x self.loss_weights = graph_data[config.graph.data][config.model.node_loss_weight].squeeze() + if config.model.get("output_mask", None) is not None: + self.output_mask = Boolean1DMask(graph_data[config.graph.data][config.model.output_mask]) + else: + self.output_mask = NoOutputMask() + self.loss_weights = self.output_mask.apply(self.loss_weights, dim=0, fill_value=0.0) + self.logger_enabled = config.diagnostics.log.wandb.enabled or config.diagnostics.log.mlflow.enabled self.metric_ranges, self.metric_ranges_validation, loss_scaling = self.metrics_loss_scaling( @@ -202,6 +211,8 @@ def advance_input( self.data_indices.internal_model.output.prognostic, ] + x[:, -1] = self.output_mask.rollout_boundary(x[:, -1], batch[:, -1], self.data_indices) + # get new "constants" needed for time-varying fields x[:, -1, :, :, self.data_indices.internal_model.input.forcing] = batch[ :, diff --git a/src/anemoi/training/train/train.py b/src/anemoi/training/train/train.py index f48b9467..b772eb2a 100644 --- a/src/anemoi/training/train/train.py +++ b/src/anemoi/training/train/train.py @@ -1,11 +1,12 @@ -# (C) Copyright 2024 ECMWF. +# (C) Copyright 2024 Anemoi contributors. # # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. 
-# + from __future__ import annotations @@ -68,6 +69,10 @@ def __init__(self, config: DictConfig) -> None: self.config.training.run_id = self.run_id LOGGER.info("Run id: %s", self.config.training.run_id) + + # Get the server2server lineage + self._get_server2server_lineage() + # Update paths to contain the run ID self._update_paths() @@ -147,7 +152,7 @@ def model(self) -> GraphForecaster: @rank_zero_only def _get_mlflow_run_id(self) -> str: run_id = self.mlflow_logger.run_id - # for resumed runs or offline runs logging this can be uesful + # for resumed runs or offline runs logging this can be useful LOGGER.info("Mlflow Run id: %s", run_id) return run_id @@ -188,18 +193,21 @@ def last_checkpoint(self) -> str | None: if not self.start_from_checkpoint: return None + fork_id = self.fork_run_server2server or self.config.training.fork_run_id checkpoint = Path( self.config.hardware.paths.checkpoints.parent, - self.config.training.fork_run_id or self.run_id, + fork_id or self.lineage_run, self.config.hardware.files.warm_start or "last.ckpt", ) - # Check if the last checkpoint exists if Path(checkpoint).exists(): LOGGER.info("Resuming training from last checkpoint: %s", checkpoint) return checkpoint - LOGGER.warning("Could not find last checkpoint: %s", checkpoint) + if rank_zero_only.rank == 0: + msg = "Could not find last checkpoint: %s", checkpoint + raise RuntimeError(msg) + return None @cached_property @@ -252,10 +260,13 @@ def profiler(self) -> PyTorchProfiler | None: def loggers(self) -> list: loggers = [] if self.config.diagnostics.log.wandb.enabled: + LOGGER.info("W&B logger enabled") loggers.append(self.wandb_logger) if self.config.diagnostics.log.tensorboard.enabled: + LOGGER.info("TensorBoard logger enabled") loggers.append(self.tensorboard_logger) if self.config.diagnostics.log.mlflow.enabled: + LOGGER.info("MLFlow logger enabled") loggers.append(self.mlflow_logger) return loggers @@ -291,17 +302,33 @@ def _log_information(self) -> None: LOGGER.debug("Effective learning rate: %.3e", total_number_of_model_instances * self.config.training.lr.rate) LOGGER.debug("Rollout window length: %d", self.config.training.rollout.start) + def _get_server2server_lineage(self) -> None: + """Get the server2server lineage.""" + self.parent_run_server2server = None + self.fork_run_server2server = None + if self.config.diagnostics.log.mlflow.enabled: + self.parent_run_server2server = self.mlflow_logger._parent_run_server2server + LOGGER.info("Parent run server2server: %s", self.parent_run_server2server) + self.fork_run_server2server = self.mlflow_logger._fork_run_server2server + LOGGER.info("Fork run server2server: %s", self.fork_run_server2server) + def _update_paths(self) -> None: """Update the paths in the configuration.""" + self.lineage_run = None if self.run_id: # when using mlflow only rank0 will have a run_id except when resuming runs # Multi-gpu new runs or forked runs - only rank 0 # Multi-gpu resumed runs - all ranks - self.config.hardware.paths.checkpoints = Path(self.config.hardware.paths.checkpoints, self.run_id) - self.config.hardware.paths.plots = Path(self.config.hardware.paths.plots, self.run_id) + self.lineage_run = self.parent_run_server2server or self.run_id + self.config.hardware.paths.checkpoints = Path(self.config.hardware.paths.checkpoints, self.lineage_run) + self.config.hardware.paths.plots = Path(self.config.hardware.paths.plots, self.lineage_run) elif self.config.training.fork_run_id: + # WHEN USING MANY NODES/GPUS + self.lineage_run = self.parent_run_server2server or 
self.config.training.fork_run_id # Only rank non zero in the forked run will go here - parent_run = self.config.training.fork_run_id - self.config.hardware.paths.checkpoints = Path(self.config.hardware.paths.checkpoints, parent_run) + self.config.hardware.paths.checkpoints = Path(self.config.hardware.paths.checkpoints, self.lineage_run) + + LOGGER.info("Checkpoints path: %s", self.config.hardware.paths.checkpoints) + LOGGER.info("Plots path: %s", self.config.hardware.paths.plots) @cached_property def strategy(self) -> DDPGroupStrategy: diff --git a/src/anemoi/training/utils/checkpoint.py b/src/anemoi/training/utils/checkpoint.py index 697a92de..ddb5a1c8 100644 --- a/src/anemoi/training/utils/checkpoint.py +++ b/src/anemoi/training/utils/checkpoint.py @@ -1,10 +1,13 @@ -# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# (C) Copyright 2024 Anemoi contributors. +# # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. + from __future__ import annotations from pathlib import Path diff --git a/src/anemoi/training/utils/jsonify.py b/src/anemoi/training/utils/jsonify.py index ddf9b86c..c44092b1 100644 --- a/src/anemoi/training/utils/jsonify.py +++ b/src/anemoi/training/utils/jsonify.py @@ -1,11 +1,13 @@ -# (C) Copyright 2024 ECMWF. +# (C) Copyright 2024 Anemoi contributors. # # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. + import datetime from pathlib import Path diff --git a/src/anemoi/training/utils/masks.py b/src/anemoi/training/utils/masks.py new file mode 100644 index 00000000..fd0581a0 --- /dev/null +++ b/src/anemoi/training/utils/masks.py @@ -0,0 +1,115 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +from __future__ import annotations + +from abc import abstractmethod +from typing import TYPE_CHECKING + +import numpy as np +import torch + +if TYPE_CHECKING: + from anemoi.models.data_indices.collection import IndexCollection + + +class BaseMask: + """Base class for masking model output.""" + + @abstractmethod + def apply(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: + error_message = "Method `apply` must be implemented in subclass." + raise NotImplementedError(error_message) + + @abstractmethod + def rollout_boundary(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: + error_message = "Method `rollout_boundary` must be implemented in subclass." 
+ raise NotImplementedError(error_message) + + +class Boolean1DMask(BaseMask): + """1D Boolean mask.""" + + def __init__(self, values: torch.Tensor) -> None: + self.mask = values.bool().squeeze() + + def broadcast_like(self, x: torch.Tensor, dim: int) -> torch.Tensor: + assert x.shape[dim] == len( + self.mask, + ), f"Dimension mismatch: dimension {dim} has size {x.shape[dim]}, but mask length is {len(self.mask)}." + target_shape = [1 for _ in range(x.ndim)] + target_shape[dim] = len(self.mask) + mask = self.mask.reshape(target_shape) + return mask.to(x.device) + + @staticmethod + def _fill_masked_tensor(x: torch.Tensor, mask: torch.Tensor, fill_value: float | torch.Tensor) -> torch.Tensor: + if isinstance(fill_value, torch.Tensor): + return x.masked_scatter(mask, fill_value) + return x.masked_fill(mask, fill_value) + + def apply(self, x: torch.Tensor, dim: int, fill_value: float | torch.Tensor = np.nan) -> torch.Tensor: + """Apply the mask to the input tensor. + + Parameters + ---------- + x : torch.Tensor + The input tensor to be masked. + dim : int + The dimension along which to apply the mask. + fill_value : float | torch.Tensor, optional + The value to fill in the masked positions, by default np.nan. + + Returns + ------- + torch.Tensor + The masked tensor with fill_value in the positions where the mask is False. + """ + mask = self.broadcast_like(x, dim) + return Boolean1DMask._fill_masked_tensor(x, ~mask, fill_value) + + def rollout_boundary( + self, + pred_state: torch.Tensor, + true_state: torch.Tensor, + data_indices: IndexCollection, + ) -> torch.Tensor: + """Rollout the boundary forcing. + + Parameters + ---------- + pred_state : torch.Tensor + The predicted state tensor of shape (bs, ens, latlon, nvar) + true_state : torch.Tensor + The true state tensor of shape (bs, ens, latlon, nvar) + data_indices : IndexCollection + Collection of data indices. + + Returns + ------- + torch.Tensor + The updated predicted state tensor with boundary forcing applied. + """ + pred_state[..., data_indices.model.input.prognostic] = self.apply( + pred_state[..., data_indices.model.input.prognostic], + dim=2, + fill_value=true_state[..., data_indices.data.output.prognostic], + ) + + return pred_state + + +class NoOutputMask(BaseMask): + """No output mask.""" + + def apply(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: # noqa: ARG002 + return x + + def rollout_boundary(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: # noqa: ARG002 + return x diff --git a/src/anemoi/training/utils/mlflow_sync.py b/src/anemoi/training/utils/mlflow_sync.py index 52ceee2f..534b4e3f 100644 --- a/src/anemoi/training/utils/mlflow_sync.py +++ b/src/anemoi/training/utils/mlflow_sync.py @@ -1,15 +1,22 @@ -# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# (C) Copyright 2024 Anemoi contributors. +# # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. 
+ import logging import os +import shutil import tempfile from itertools import starmap from pathlib import Path +from urllib.parse import urlparse + +import mlflow.entities def export_log_output_file_path() -> tempfile._TemporaryFileWrapper: @@ -26,11 +33,21 @@ def export_log_output_file_path() -> tempfile._TemporaryFileWrapper: Path(tmpdir).mkdir(parents=True, exist_ok=True) temp = tempfile.NamedTemporaryFile(dir=tmpdir, prefix=f"{user}_") # noqa: SIM115 os.environ["MLFLOW_EXPORT_IMPORT_LOG_OUTPUT_FILE"] = temp.name + os.environ["MLFLOW_EXPORT_IMPORT_TMP_DIRECTORY"] = tmpdir return temp +def close_and_clean_temp(server2server: str, artifact_path: Path) -> None: + temp.close() + os.environ.pop("MLFLOW_EXPORT_IMPORT_LOG_OUTPUT_FILE") + os.environ.pop("MLFLOW_EXPORT_IMPORT_TMP_DIRECTORY") + if server2server: + shutil.rmtree(artifact_path) + + temp = export_log_output_file_path() + import mlflow # noqa: E402 from mlflow.entities import RunStatus # noqa: E402 from mlflow.entities import RunTag # noqa: E402 @@ -110,22 +127,22 @@ def __init__( LOGGER.setLevel(self.log_level) @staticmethod - def update_run_id(params: dict, key: str, new_run_id: str, offline_run_id: str) -> dict: + def update_run_id(params: dict, key: str, new_run_id: str, src_run_id: str, run_type: str) -> dict: params[f"config.training.{key}"] = new_run_id - params[f"config.training.offline_{key}"] = offline_run_id + params[f"config.training.{run_type}_{key}"] = src_run_id if key == "run_id": - params[f"metadata.offline_{key}"] = offline_run_id + params[f"metadata.{run_type}_{key}"] = src_run_id params[f"metadata.{key}"] = new_run_id return params - def update_parent_run_info(self, tags: dict, tag_key: str, tag_dest: str, dst_run_id: str) -> dict: + def update_parent_run_info(self, tags: dict, tag_key: str, tag_dest: str, dst_run_id: str, run_type: str) -> dict: mlflow.set_tracking_uri(self.dest_tracking_uri) # Check if there is already a parent run in the destination tracking uri runs = mlflow.search_runs( experiment_ids=mlflow.get_experiment_by_name(self.experiment_name).experiment_id, - filter_string=f"params.metadata.offline_run_id = '{tags[tag_key]}'", + filter_string=f"params.metadata.{run_type}_run_id = '{tags[tag_key]}'", ) if not runs.empty: @@ -139,19 +156,128 @@ def update_parent_run_info(self, tags: dict, tag_key: str, tag_dest: str, dst_ru tags[tag_key] = new_parent_run_id # update new online parent run_id return tags - def check_run_is_logged(self, status: str = "FINISHED") -> bool: + def check_run_is_logged(self, status: str = "FINISHED", server2server: bool = False) -> bool: """Blocks sync if top-level parent run or single runs are unavailable.""" run_logged = False if status == "FINISHED": mlflow.set_tracking_uri(self.dest_tracking_uri) - synced_runs = mlflow.search_runs( - experiment_ids=mlflow.get_experiment_by_name(self.experiment_name).experiment_id, - filter_string=f"params.metadata.offline_run_id = '{self.run_id}'", - ) - if not synced_runs.empty: # single run (no child) already logged - run_logged = True + experiment = mlflow.get_experiment_by_name(self.experiment_name) + run_type = "server2server" if server2server else "offline" + if experiment: + synced_runs = mlflow.search_runs( + experiment_ids=experiment.experiment_id, + filter_string=f"params.metadata.{run_type}_run_id = '{self.run_id}'", + ) + if not synced_runs.empty: # single run (no child) already logged + run_logged = True return run_logged + def _check_source_tracking_uri(self) -> bool: + parsed_url = urlparse(self.source_tracking_uri) + 
return all([parsed_url.scheme, parsed_url.netloc]) # True if source_tracking_uri is a remote server + + def _get_dst_experiment_id(self, dest_mlflow_client: str) -> str: + experiment = dest_mlflow_client.get_experiment_by_name(self.experiment_name) + if not experiment: + return dest_mlflow_client.create_experiment(self.experiment_name) + return experiment.experiment_id + + def _get_artifacts_path(self, server2server: str, run: mlflow.entities.Run) -> Path: + if server2server: + # Download each artifact + temp_dir = os.getenv("MLFLOW_EXPORT_IMPORT_TMP_DIRECTORY") + artifact_path = Path(temp_dir, run.info.run_id) + artifact_path.mkdir(parents=True, exist_ok=True) + else: + artifact_path = Path(self.source_tracking_uri, run.info.experiment_id, run.info.run_id, "artifacts") + + return artifact_path + + def _download_artifacts( + self, + client: mlflow.tracking.client.MlflowClient, + run_id: mlflow.entities.Run, + artifact_path: Path, + ) -> None: + + mlflow.set_tracking_uri(self.source_tracking_uri) # OTHERWISE IT WILL NOT WORK + artifacts = client.list_artifacts(run_id) + LOGGER.info("Downloading artifacts %s for run %s to %s", len(artifacts), run_id, artifact_path) + for artifact in artifacts: + # Download artifact file from the server + mlflow.artifacts.download_artifacts(run_id=run_id, artifact_path=artifact.path, dst_path=artifact_path) + + def _update_params_tags_runs( + self, + params: dict, + tags: dict, + dst_run_id: str, + src_run_id: str, + run_type: str = "offline", + ) -> (dict, dict): + + if (params["config.training.fork_run_id"] == "None") and (params["metadata.run_id"] == src_run_id): + params = self.update_run_id( + params, + "run_id", + new_run_id=dst_run_id, + src_run_id=src_run_id, + run_type=run_type, + ) + + elif "forkedRun" in tags: + try: + tags = self.update_parent_run_info( + tags=tags, + tag_key="forkedRunId", + tag_dest=f"{run_type}.forkedRunId", + dst_run_id=dst_run_id, + run_type=run_type, + ) + params = self.update_run_id( + params, + "fork_run_id", + new_run_id=tags["forkedRunId"], + src_run_id=tags[f"{run_type}.forkedRunId"], + run_type=run_type, + ) + params = self.update_run_id( + params, + "run_id", + new_run_id=dst_run_id, + src_run_id=src_run_id, + run_type=run_type, + ) + + except AttributeError: + LOGGER.warning("No forked run parent found") + + elif "resumedRun" in tags: + try: + tags = self.update_parent_run_info( + tags=tags, + tag_key="mlflow.parentRunId", + tag_dest=f"mlflow.{run_type}.parentRunId", + dst_run_id=dst_run_id, + run_type=run_type, + ) + params = self.update_run_id( + params, + "run_id", + new_run_id=tags["mlflow.parentRunId"], + src_run_id=tags[f"mlflow.{run_type}.parentRunId"], + run_type=run_type, + ) + + # in the offline case that's the local folder name for the resumed run + # in the server2server case that's the source server run_id of the resumed run + params[f"config.training.{run_type}_self_run_id"] = src_run_id + + except AttributeError: + LOGGER.warning("No parent run found") + + return params, tags + def sync( self, ) -> None: @@ -161,6 +287,7 @@ def sync( http_client = create_http_client(dest_mlflow_client) # GET SOURCE RUN ## run = src_mlflow_client.get_run(self.run_id) + server2server = self._check_source_tracking_uri() run_logged = self.check_run_is_logged(status=run.info.status) if run_logged: LOGGER.info("Run already imported %s into experiment %s", self.run_id, self.experiment_name) @@ -174,6 +301,7 @@ def sync( ) return + msg = { "run_id": run.info.run_id, "lifecycle_stage": run.info.lifecycle_stage, @@ -184,69 
+312,44 @@ def sync( run_info = mlflow_utils.strip_underscores(run.info) src_user_id = run_info["user_id"] - exp = dest_mlflow_client.get_experiment_by_name(self.experiment_name) - dst_run = dest_mlflow_client.create_run(exp.experiment_id) + exp_id = self._get_dst_experiment_id(dest_mlflow_client=dest_mlflow_client) + dst_run = dest_mlflow_client.create_run(exp_id) dst_run_id = dst_run.info.run_id tags = dict(sorted(run.data.tags.items())) params = run.data.params # So far there is no easy way to force mlflow to use a specific run_id, that means - # that when we online sync the offline runs those will have run run_ids. To keep + # that when we online sync the offline runs those will have different run_ids. To keep # track of online and offline governance in that case we update run_ids info - if (params["config.training.fork_run_id"] == "None") and (params["metadata.run_id"] == run.info.run_id): - params = self.update_run_id(params, "run_id", new_run_id=dst_run_id, offline_run_id=run.info.run_id) + artifact_path = self._get_artifacts_path(server2server, run) - elif "forkedRun" in tags: - try: - tags = self.update_parent_run_info( - tags=tags, - tag_key="forkedRunId", - tag_dest="offline.forkedRunId", - dst_run_id=dst_run_id, - ) - params = self.update_run_id( - params, - "fork_run_id", - new_run_id=tags["forkedRunId"], - offline_run_id=tags["offline.forkedRunId"], - ) - params = self.update_run_id(params, "run_id", new_run_id=dst_run_id, offline_run_id=run.info.run_id) - - except AttributeError: - LOGGER.warning("No forked run parent found") - - elif "resumedRun" in tags: - try: - tags = self.update_parent_run_info( - tags=tags, - tag_key="mlflow.parentRunId", - tag_dest="mlflow.offline.parentRunId", - dst_run_id=dst_run_id, - ) - params = self.update_run_id( - params, - "run_id", - new_run_id=tags["mlflow.parentRunId"], - offline_run_id=tags["mlflow.offline.parentRunId"], - ) - - params["config.training.offline_run_id_folder"] = run.info.run_id - - except AttributeError: - LOGGER.warning("No parent run found") + if server2server: + tags["server2server"] = "True" + self._download_artifacts(src_mlflow_client, run.info.run_id, artifact_path) + params, tags = self._update_params_tags_runs( + params, + tags, + dst_run_id, + run.info.run_id, + run_type="server2server", + ) - tags["offlineRun"] = "True" + else: + tags["offlineRun"] = "True" + params, tags = self._update_params_tags_runs(params, tags, dst_run_id, run.info.run_id, run_type="offline") src_run_dct = { - "params": run.data.params, + "params": params, "metrics": _get_metrics_with_steps(src_mlflow_client, run), "tags": tags, "inputs": _inputs_to_dict(run.inputs), } try: + LOGGER.info("Starting to export run data") + import_run_data( dest_mlflow_client, src_run_dct, @@ -255,19 +358,22 @@ def sync( ) _import_inputs(http_client, src_run_dct, dst_run_id) - path = Path(self.source_tracking_uri, run.info.experiment_id, self.run_id, "artifacts") - if path.exists(): - mlflow.set_tracking_uri(self.dest_tracking_uri) - dest_mlflow_client.log_artifacts(dst_run_id, path) + mlflow.set_tracking_uri(self.dest_tracking_uri) + dest_mlflow_client.log_artifacts(dst_run_id, artifact_path) dest_mlflow_client.set_terminated(dst_run_id, RunStatus.to_string(RunStatus.FINISHED)) - except Exception as e: + except BaseException: dest_mlflow_client.set_terminated(dst_run_id, RunStatus.to_string(RunStatus.FAILED)) import traceback traceback.print_exc() - raise Exception(e, "Importing run %s of experiment %s failed", dst_run_id, exp.name) from e # noqa: TRY002 + 
LOGGER.exception( + "Importing run %s of experiment %s failed", + dst_run_id, + self.experiment_name, + ) - LOGGER.info("Imported run %s into experiment %s", dst_run_id, self.experiment_name) + finally: + close_and_clean_temp(server2server, artifact_path) - temp.close() + LOGGER.info("Imported run %s into experiment %s", dst_run_id, self.experiment_name) diff --git a/src/anemoi/training/utils/seeding.py b/src/anemoi/training/utils/seeding.py index d766bd1a..3b4afd47 100644 --- a/src/anemoi/training/utils/seeding.py +++ b/src/anemoi/training/utils/seeding.py @@ -1,7 +1,8 @@ -# (C) Copyright 2024 ECMWF. +# (C) Copyright 2024 Anemoi contributors. # # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. diff --git a/src/anemoi/training/utils/usable_indices.py b/src/anemoi/training/utils/usable_indices.py index 7bdd5cbd..0d97f25f 100644 --- a/src/anemoi/training/utils/usable_indices.py +++ b/src/anemoi/training/utils/usable_indices.py @@ -1,10 +1,13 @@ -# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# (C) Copyright 2024 Anemoi contributors. +# # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. + from __future__ import annotations import numpy as np diff --git a/src/hydra_plugins/anemoi_searchpath/anemoi_searchpath_plugin.py b/src/hydra_plugins/anemoi_searchpath/anemoi_searchpath_plugin.py index 75efae9a..db44dff7 100644 --- a/src/hydra_plugins/anemoi_searchpath/anemoi_searchpath_plugin.py +++ b/src/hydra_plugins/anemoi_searchpath/anemoi_searchpath_plugin.py @@ -1,11 +1,13 @@ -# (C) Copyright 2024 ECMWF. +# (C) Copyright 2024 Anemoi contributors. # # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. + import logging import os from pathlib import Path diff --git a/tests/conftest.py b/tests/conftest.py index 711b1e2d..46163e8d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,8 @@ -# (C) Copyright 2024 ECMWF. +# (C) Copyright 2024 Anemoi contributors. # # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. diff --git a/tests/diagnostics/mlflow/test_auth.py b/tests/diagnostics/mlflow/test_auth.py index 9f3a709b..329f7217 100644 --- a/tests/diagnostics/mlflow/test_auth.py +++ b/tests/diagnostics/mlflow/test_auth.py @@ -1,9 +1,12 @@ -# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# (C) Copyright 2024 Anemoi contributors. 
+# # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. + from __future__ import annotations import os diff --git a/tests/diagnostics/mlflow/test_client.py b/tests/diagnostics/mlflow/test_client.py index 93c0aeeb..f6dedbce 100644 --- a/tests/diagnostics/mlflow/test_client.py +++ b/tests/diagnostics/mlflow/test_client.py @@ -1,9 +1,12 @@ -# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# (C) Copyright 2024 Anemoi contributors. +# # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. + from __future__ import annotations from typing import TYPE_CHECKING diff --git a/tests/diagnostics/test_checkpoint.py b/tests/diagnostics/test_checkpoint.py index e2dce376..63e6ccc9 100644 --- a/tests/diagnostics/test_checkpoint.py +++ b/tests/diagnostics/test_checkpoint.py @@ -1,10 +1,13 @@ -# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# (C) Copyright 2024 Anemoi contributors. +# # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. + from __future__ import annotations import datetime diff --git a/tests/hydra/test_search_path_plugins.py b/tests/hydra/test_search_path_plugins.py index dd981b66..48666502 100644 --- a/tests/hydra/test_search_path_plugins.py +++ b/tests/hydra/test_search_path_plugins.py @@ -1,11 +1,13 @@ -# (C) Copyright 2024 ECMWF. +# (C) Copyright 2024 Anemoi contributors. # # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. + from hydra import initialize from hydra.core.global_hydra import GlobalHydra from hydra.core.plugins import Plugins diff --git a/tests/train/test_loss_scaling.py b/tests/train/test_loss_scaling.py index 84ca2189..2da5ae00 100644 --- a/tests/train/test_loss_scaling.py +++ b/tests/train/test_loss_scaling.py @@ -1,10 +1,13 @@ -# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# (C) Copyright 2024 Anemoi contributors. +# # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. 
+ import pytest import torch from _pytest.fixtures import SubRequest diff --git a/tests/utils/test_usable_indices.py b/tests/utils/test_usable_indices.py index 6bc5c83f..0aff358a 100644 --- a/tests/utils/test_usable_indices.py +++ b/tests/utils/test_usable_indices.py @@ -1,10 +1,13 @@ -# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# (C) Copyright 2024 Anemoi contributors. +# # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. + import numpy as np from anemoi.training.utils.usable_indices import get_usable_indices
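
Reviewer note (not part of the patch): the central behavioural addition above is the new anemoi.training.utils.masks module, where Boolean1DMask.apply broadcasts a 1D boolean grid mask along a chosen dimension and fills the positions where the mask is False with a value such as NaN; the plotting callbacks and the forecaster's rollout boundary forcing then build on this. Below is a minimal standalone sketch of that masking semantics only, using torch directly with illustrative tensor shapes; none of the variable names here come from the anemoi codebase.

# Sketch of the Boolean1DMask.apply semantics introduced in this diff:
# positions where the 1D grid mask is False are overwritten with fill_value
# (NaN here), broadcast along the chosen dimension. Shapes are illustrative.
import torch

mask = torch.tensor([True, True, False, True])  # 1D mask over grid points (latlon)
x = torch.arange(2 * 4 * 3, dtype=torch.float32).reshape(2, 4, 3)  # (batch, latlon, nvar)

# Broadcast the 1D mask to the latlon dimension (dim=1), mirroring broadcast_like
target_shape = [1] * x.ndim
target_shape[1] = len(mask)
broadcast = mask.reshape(target_shape)

masked = x.masked_fill(~broadcast, float("nan"))  # NaN wherever the mask is False
print(masked[:, 2, :])  # grid point 2 is outside the mask, so all its variables are NaN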