Add fixture for caching everest test-data example
yngve-sk committed Dec 3, 2024
1 parent 22687ea commit c686535
Showing 3 changed files with 99 additions and 75 deletions.
52 changes: 52 additions & 0 deletions tests/everest/conftest.py
@@ -1,5 +1,6 @@
import os
import shutil
import tempfile
from copy import deepcopy
from pathlib import Path
from typing import Callable, Dict, Iterator, Optional, Union
@@ -8,6 +9,8 @@

from ert.config import QueueSystem
from ert.ensemble_evaluator import EvaluatorServerConfig
from ert.run_models.everest_run_model import EverestRunModel
from everest.config import EverestConfig
from everest.config.control_config import ControlConfig
from tests.everest.utils import relpath

@@ -144,3 +147,52 @@ def create_evaluator_server_config(run_model):
        )

    return create_evaluator_server_config


@pytest.fixture
def cached_example(pytestconfig, evaluator_server_config_generator):
    cache = pytestconfig.cache

    def run_config(test_data_case: str):
        if cache.get(f"cached_example:{test_data_case}", None) is None:
            my_tmpdir = Path(tempfile.mkdtemp())
            config_path = (
                Path(__file__) / f"../../../test-data/everest/{test_data_case}"
            ).resolve()
            config_file = config_path.name

            shutil.copytree(config_path.parent, my_tmpdir / "everest")
            config = EverestConfig.load_file(my_tmpdir / "everest" / config_file)
            run_model = EverestRunModel.create(config)
            evaluator_server_config = evaluator_server_config_generator(run_model)
            try:
                run_model.run_experiment(evaluator_server_config)
            except Exception as e:
                raise Exception(f"Failed running {config_path} with error: {e}") from e

            result_path = my_tmpdir / "everest"

            optimal_result = run_model.result
            optimal_result_json = {
                "batch": optimal_result.batch,
                "controls": optimal_result.controls,
                "total_objective": optimal_result.total_objective,
            }

            cache.set(
                f"cached_example:{test_data_case}",
                (str(result_path), config_file, optimal_result_json),
            )

        result_path, config_file, optimal_result_json = cache.get(
            f"cached_example:{test_data_case}", (None, None, None)
        )

        copied_tmpdir = tempfile.mkdtemp()
        shutil.copytree(result_path, Path(copied_tmpdir) / "everest")
        copied_path = str(Path(copied_tmpdir) / "everest")
        os.chdir(copied_path)

        return copied_path, config_file, optimal_result_json

    return run_config
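
For orientation, a minimal sketch of how a test can consume the new fixture, mirroring the updated tests below. The test name is hypothetical, math_func/config_minimal.yml is one of the existing everest test-data cases, and the assertions are illustrative only.

from pathlib import Path


def test_uses_cached_example(cached_example):  # hypothetical test name
    # The first use runs the optimization and records the result directory and
    # optimum in pytest's cache; subsequent uses copy the cached directory into
    # a fresh tmp dir and chdir into it.
    config_path, config_file, optimal_result_json = cached_example(
        "math_func/config_minimal.yml"
    )
    assert (Path(config_path) / config_file).exists()
    assert "total_objective" in optimal_result_json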
25 changes: 6 additions & 19 deletions tests/everest/entry_points/test_visualization_entry.py
@@ -1,34 +1,21 @@
import sys
from unittest.mock import PropertyMock, patch
from pathlib import Path
from unittest.mock import patch

import pytest

from everest.bin.visualization_script import visualization_entry
from everest.config import EverestConfig
from everest.detached import ServerStatus
from tests.everest.utils import capture_streams, relpath

CONFIG_PATH = relpath(
    "..", "..", "test-data", "everest", "math_func", "config_advanced.yml"
)
CACHED_SEBA_FOLDER = relpath("test_data", "cached_results_config_advanced")
from tests.everest.utils import capture_streams


@patch.object(
    EverestConfig,
    "optimization_output_dir",
    new_callable=PropertyMock,
    return_value=CACHED_SEBA_FOLDER,
)
@patch(
    "everest.bin.visualization_script.everserver_status",
    return_value={"status": ServerStatus.completed},
)
@pytest.mark.skipif(sys.version_info.major < 3, reason="requires python3 or higher")
def test_visualization_entry(
    opt_dir_mock,
    server_status_mock,
):
def test_visualization_entry(_, cached_example):
    config_path, config_file, _ = cached_example("math_func/config_advanced.yml")
    with capture_streams() as (out, _):
        visualization_entry([CONFIG_PATH])
        visualization_entry([str(Path(config_path) / config_file)])
    assert "No visualization plugin installed!" in out.getvalue()
97 changes: 41 additions & 56 deletions tests/everest/test_api_snapshots.py
@@ -1,13 +1,14 @@
import json
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Dict

import orjson
import polars
import pytest

from ert.config import SummaryConfig
from ert.run_models.everest_run_model import EverestRunModel
from ert.storage import open_storage
from everest.api import EverestDataAPI
from everest.config import EverestConfig

@@ -62,24 +63,11 @@ def make_api_snapshot(api) -> Dict[str, Any]:
"config_stddev.yml",
],
)
def test_api_snapshots(
    config_file,
    copy_math_func_test_data_to_tmp,
    evaluator_server_config_generator,
    snapshot,
):
    config = EverestConfig.load_file(config_file)
    run_model = EverestRunModel.create(config)
    evaluator_server_config = evaluator_server_config_generator(run_model)
    run_model.run_experiment(evaluator_server_config)

    optimal_result = run_model.result
    optimal_result_json = {
        "batch": optimal_result.batch,
        "controls": optimal_result.controls,
        "total_objective": optimal_result.total_objective,
    }

def test_api_snapshots(config_file, snapshot, cached_example):
    config_path, config_file, optimal_result_json = cached_example(
        f"math_func/{config_file}"
    )
    config = EverestConfig.load_file(Path(config_path) / config_file)
    api = EverestDataAPI(config)
    json_snapshot = make_api_snapshot(api)
    json_snapshot["optimal_result_json"] = optimal_result_json
@@ -94,43 +82,40 @@ def test_api_snapshots(
    snapshot.assert_match(snapshot_str, "snapshot.json")


def test_api_summary_snapshot(
    copy_math_func_test_data_to_tmp, evaluator_server_config_generator, snapshot
):
    config = EverestConfig.load_file("config_minimal.yml")
    run_model = EverestRunModel.create(config)
    evaluator_server_config = evaluator_server_config_generator(run_model)
    run_model.run_experiment(evaluator_server_config)

    # Save some summary data to each ensemble
    experiment = next(run_model._storage.experiments)

    response_config = experiment.response_configuration
    response_config["summary"] = SummaryConfig(keys=["*"])

    experiment._storage._write_transaction(
        experiment._path / experiment._responses_file,
        json.dumps(
            {c.response_type: c.to_dict() for c in response_config.values()},
            default=str,
            indent=2,
        ).encode("utf-8"),
    )

    smry_data = polars.DataFrame(
        {
            "response_key": ["FOPR", "FOPR", "WOPR", "WOPR", "FOPT", "FOPT"],
            "time": polars.Series(
                [datetime(2000, 1, 1) + timedelta(days=i) for i in range(6)]
            ).dt.cast_time_unit("ms"),
            "values": polars.Series(
                [0.2, 0.2, 1.0, 1.1, 3.3, 3.3], dtype=polars.Float32
            ),
        }
    )
    for ens in experiment.ensembles:
        for real in range(ens.ensemble_size):
            ens.save_response("summary", smry_data.clone(), real)
def test_api_summary_snapshot(snapshot, cached_example):
    config_path, config_file, _ = cached_example("math_func/config_minimal.yml")
    config = EverestConfig.load_file(Path(config_path) / config_file)

    with open_storage(config.storage_dir, mode="w") as storage:
        # Save some summary data to each ensemble
        experiment = next(storage.experiments)

        response_config = experiment.response_configuration
        response_config["summary"] = SummaryConfig(keys=["*"])

        experiment._storage._write_transaction(
            experiment._path / experiment._responses_file,
            json.dumps(
                {c.response_type: c.to_dict() for c in response_config.values()},
                default=str,
                indent=2,
            ).encode("utf-8"),
        )

        smry_data = polars.DataFrame(
            {
                "response_key": ["FOPR", "FOPR", "WOPR", "WOPR", "FOPT", "FOPT"],
                "time": polars.Series(
                    [datetime(2000, 1, 1) + timedelta(days=i) for i in range(6)]
                ).dt.cast_time_unit("ms"),
                "values": polars.Series(
                    [0.2, 0.2, 1.0, 1.1, 3.3, 3.3], dtype=polars.Float32
                ),
            }
        )
        for ens in experiment.ensembles:
            for real in range(ens.ensemble_size):
                ens.save_response("summary", smry_data.clone(), real)

    api = EverestDataAPI(config)
    dicts = polars.from_pandas(api.summary_values()).to_dicts()