From c686535222988ea6f4d6ad1bc3ad8c1e6caaf254 Mon Sep 17 00:00:00 2001 From: "Yngve S. Kristiansen" Date: Mon, 2 Dec 2024 14:44:39 +0100 Subject: [PATCH] Add fixture for caching everest test-data example --- tests/everest/conftest.py | 52 ++++++++++ .../entry_points/test_visualization_entry.py | 25 ++--- tests/everest/test_api_snapshots.py | 97 ++++++++----------- 3 files changed, 99 insertions(+), 75 deletions(-) diff --git a/tests/everest/conftest.py b/tests/everest/conftest.py index 8de987e6823..cba6e2aa6bc 100644 --- a/tests/everest/conftest.py +++ b/tests/everest/conftest.py @@ -1,5 +1,6 @@ import os import shutil +import tempfile from copy import deepcopy from pathlib import Path from typing import Callable, Dict, Iterator, Optional, Union @@ -8,6 +9,8 @@ from ert.config import QueueSystem from ert.ensemble_evaluator import EvaluatorServerConfig +from ert.run_models.everest_run_model import EverestRunModel +from everest.config import EverestConfig from everest.config.control_config import ControlConfig from tests.everest.utils import relpath @@ -144,3 +147,52 @@ def create_evaluator_server_config(run_model): ) return create_evaluator_server_config + + +@pytest.fixture +def cached_example(pytestconfig, evaluator_server_config_generator): + cache = pytestconfig.cache + + def run_config(test_data_case: str): + if cache.get(f"cached_example:{test_data_case}", None) is None: + my_tmpdir = Path(tempfile.mkdtemp()) + config_path = ( + Path(__file__) / f"../../../test-data/everest/{test_data_case}" + ).resolve() + config_file = config_path.name + + shutil.copytree(config_path.parent, my_tmpdir / "everest") + config = EverestConfig.load_file(my_tmpdir / "everest" / config_file) + run_model = EverestRunModel.create(config) + evaluator_server_config = evaluator_server_config_generator(run_model) + try: + run_model.run_experiment(evaluator_server_config) + except Exception as e: + raise Exception(f"Failed running {config_path} with error: {e}") from e + + result_path = my_tmpdir / "everest" + + optimal_result = run_model.result + optimal_result_json = { + "batch": optimal_result.batch, + "controls": optimal_result.controls, + "total_objective": optimal_result.total_objective, + } + + cache.set( + f"cached_example:{test_data_case}", + (str(result_path), config_file, optimal_result_json), + ) + + result_path, config_file, optimal_result_json = cache.get( + f"cached_example:{test_data_case}", (None, None, None) + ) + + copied_tmpdir = tempfile.mkdtemp() + shutil.copytree(result_path, Path(copied_tmpdir) / "everest") + copied_path = str(Path(copied_tmpdir) / "everest") + os.chdir(copied_path) + + return copied_path, config_file, optimal_result_json + + return run_config diff --git a/tests/everest/entry_points/test_visualization_entry.py b/tests/everest/entry_points/test_visualization_entry.py index fbeabdf0e72..8bdc49ad717 100644 --- a/tests/everest/entry_points/test_visualization_entry.py +++ b/tests/everest/entry_points/test_visualization_entry.py @@ -1,34 +1,21 @@ import sys -from unittest.mock import PropertyMock, patch +from pathlib import Path +from unittest.mock import patch import pytest from everest.bin.visualization_script import visualization_entry -from everest.config import EverestConfig from everest.detached import ServerStatus -from tests.everest.utils import capture_streams, relpath - -CONFIG_PATH = relpath( - "..", "..", "test-data", "everest", "math_func", "config_advanced.yml" -) -CACHED_SEBA_FOLDER = relpath("test_data", "cached_results_config_advanced") +from tests.everest.utils import capture_streams -@patch.object( - EverestConfig, - "optimization_output_dir", - new_callable=PropertyMock, - return_value=CACHED_SEBA_FOLDER, -) @patch( "everest.bin.visualization_script.everserver_status", return_value={"status": ServerStatus.completed}, ) @pytest.mark.skipif(sys.version_info.major < 3, reason="requires python3 or higher") -def test_visualization_entry( - opt_dir_mock, - server_status_mock, -): +def test_visualization_entry(_, cached_example): + config_path, config_file, _ = cached_example("math_func/config_advanced.yml") with capture_streams() as (out, _): - visualization_entry([CONFIG_PATH]) + visualization_entry([str(Path(config_path) / config_file)]) assert "No visualization plugin installed!" in out.getvalue() diff --git a/tests/everest/test_api_snapshots.py b/tests/everest/test_api_snapshots.py index 4b2d1fcfac2..d35ce482b57 100644 --- a/tests/everest/test_api_snapshots.py +++ b/tests/everest/test_api_snapshots.py @@ -1,5 +1,6 @@ import json from datetime import datetime, timedelta +from pathlib import Path from typing import Any, Dict import orjson @@ -7,7 +8,7 @@ import pytest from ert.config import SummaryConfig -from ert.run_models.everest_run_model import EverestRunModel +from ert.storage import open_storage from everest.api import EverestDataAPI from everest.config import EverestConfig @@ -62,24 +63,11 @@ def make_api_snapshot(api) -> Dict[str, Any]: "config_stddev.yml", ], ) -def test_api_snapshots( - config_file, - copy_math_func_test_data_to_tmp, - evaluator_server_config_generator, - snapshot, -): - config = EverestConfig.load_file(config_file) - run_model = EverestRunModel.create(config) - evaluator_server_config = evaluator_server_config_generator(run_model) - run_model.run_experiment(evaluator_server_config) - - optimal_result = run_model.result - optimal_result_json = { - "batch": optimal_result.batch, - "controls": optimal_result.controls, - "total_objective": optimal_result.total_objective, - } - +def test_api_snapshots(config_file, snapshot, cached_example): + config_path, config_file, optimal_result_json = cached_example( + f"math_func/{config_file}" + ) + config = EverestConfig.load_file(Path(config_path) / config_file) api = EverestDataAPI(config) json_snapshot = make_api_snapshot(api) json_snapshot["optimal_result_json"] = optimal_result_json @@ -94,43 +82,40 @@ def test_api_snapshots( snapshot.assert_match(snapshot_str, "snapshot.json") -def test_api_summary_snapshot( - copy_math_func_test_data_to_tmp, evaluator_server_config_generator, snapshot -): - config = EverestConfig.load_file("config_minimal.yml") - run_model = EverestRunModel.create(config) - evaluator_server_config = evaluator_server_config_generator(run_model) - run_model.run_experiment(evaluator_server_config) - - # Save some summary data to each ensemble - experiment = next(run_model._storage.experiments) - - response_config = experiment.response_configuration - response_config["summary"] = SummaryConfig(keys=["*"]) - - experiment._storage._write_transaction( - experiment._path / experiment._responses_file, - json.dumps( - {c.response_type: c.to_dict() for c in response_config.values()}, - default=str, - indent=2, - ).encode("utf-8"), - ) - - smry_data = polars.DataFrame( - { - "response_key": ["FOPR", "FOPR", "WOPR", "WOPR", "FOPT", "FOPT"], - "time": polars.Series( - [datetime(2000, 1, 1) + timedelta(days=i) for i in range(6)] - ).dt.cast_time_unit("ms"), - "values": polars.Series( - [0.2, 0.2, 1.0, 1.1, 3.3, 3.3], dtype=polars.Float32 - ), - } - ) - for ens in experiment.ensembles: - for real in range(ens.ensemble_size): - ens.save_response("summary", smry_data.clone(), real) +def test_api_summary_snapshot(snapshot, cached_example): + config_path, config_file, _ = cached_example("math_func/config_minimal.yml") + config = EverestConfig.load_file(Path(config_path) / config_file) + + with open_storage(config.storage_dir, mode="w") as storage: + # Save some summary data to each ensemble + experiment = next(storage.experiments) + + response_config = experiment.response_configuration + response_config["summary"] = SummaryConfig(keys=["*"]) + + experiment._storage._write_transaction( + experiment._path / experiment._responses_file, + json.dumps( + {c.response_type: c.to_dict() for c in response_config.values()}, + default=str, + indent=2, + ).encode("utf-8"), + ) + + smry_data = polars.DataFrame( + { + "response_key": ["FOPR", "FOPR", "WOPR", "WOPR", "FOPT", "FOPT"], + "time": polars.Series( + [datetime(2000, 1, 1) + timedelta(days=i) for i in range(6)] + ).dt.cast_time_unit("ms"), + "values": polars.Series( + [0.2, 0.2, 1.0, 1.1, 3.3, 3.3], dtype=polars.Float32 + ), + } + ) + for ens in experiment.ensembles: + for real in range(ens.ensemble_size): + ens.save_response("summary", smry_data.clone(), real) api = EverestDataAPI(config) dicts = polars.from_pandas(api.summary_values()).to_dicts()