From 36897088f55cae2a38e1985cac97d4fb6eff7b7d Mon Sep 17 00:00:00 2001 From: Frode Aarstad Date: Fri, 13 Dec 2024 09:16:41 +0100 Subject: [PATCH] Stop using cached seba data --- .../entry_points/test_config_branch_entry.py | 58 ++++++------------- 1 file changed, 17 insertions(+), 41 deletions(-) diff --git a/tests/everest/entry_points/test_config_branch_entry.py b/tests/everest/entry_points/test_config_branch_entry.py index 4207f73d247..5e5f42d22b7 100644 --- a/tests/everest/entry_points/test_config_branch_entry.py +++ b/tests/everest/entry_points/test_config_branch_entry.py @@ -1,41 +1,27 @@ import difflib from os.path import exists -from unittest.mock import PropertyMock, patch +from pathlib import Path from seba_sqlite.snapshot import SebaSnapshot from everest.bin.config_branch_script import config_branch_entry -from everest.config import EverestConfig from everest.config_file_loader import load_yaml from everest.config_keys import ConfigKeys as CK -from tests.everest.utils import relpath -CONFIG_FILE = "config_advanced.yml" -CACHED_SEBA_FOLDER = relpath("test_data", "cached_results_config_advanced") +def test_config_branch_entry(cached_example): + path, _, _ = cached_example("math_func/config_advanced.yml") -# @patch.object(EverestConfig, "optimization_output_dir", new_callable=PropertyMock) -@patch.object( - EverestConfig, - "optimization_output_dir", - new_callable=PropertyMock, - return_value=CACHED_SEBA_FOLDER, -) -def test_config_branch_entry(get_opt_output_dir_mock, copy_math_func_test_data_to_tmp): - new_config_file_name = "new_restart_config.yml" - batch_id = 1 + config_branch_entry(["config_advanced.yml", "new_restart_config.yml", "-b", "1"]) - config_branch_entry([CONFIG_FILE, new_config_file_name, "-b", str(batch_id)]) + assert exists("new_restart_config.yml") - get_opt_output_dir_mock.assert_called_once() - assert exists(new_config_file_name) - - old_config = load_yaml(CONFIG_FILE) + old_config = load_yaml("config_advanced.yml") old_controls = old_config[CK.CONTROLS] assert CK.INITIAL_GUESS in old_controls[0] - new_config = load_yaml(new_config_file_name) + new_config = load_yaml("new_restart_config.yml") new_controls = new_config[CK.CONTROLS] assert CK.INITIAL_GUESS not in new_controls[0] @@ -44,9 +30,9 @@ def test_config_branch_entry(get_opt_output_dir_mock, copy_math_func_test_data_t opt_controls = {} - snapshot = SebaSnapshot(CACHED_SEBA_FOLDER) + snapshot = SebaSnapshot(Path(path) / "everest_output" / "optimization_output") for opt_data in snapshot._optimization_data(): - if opt_data.batch_id == batch_id: + if opt_data.batch_id == 1: opt_controls = opt_data.controls new_controls_initial_guesses = { @@ -57,36 +43,26 @@ def test_config_branch_entry(get_opt_output_dir_mock, copy_math_func_test_data_t assert new_controls_initial_guesses == opt_control_val_for_batch_id -@patch.object( - EverestConfig, - "optimization_output_dir", - new_callable=PropertyMock, - return_value=CACHED_SEBA_FOLDER, -) -def test_config_branch_preserves_config_section_order( - get_opt_output_dir_mock, copy_math_func_test_data_to_tmp -): - new_config_file_name = "new_restart_config.yml" - batch_id = 1 +def test_config_branch_preserves_config_section_order(cached_example): + path, _, _ = cached_example("math_func/config_advanced.yml") - config_branch_entry([CONFIG_FILE, new_config_file_name, "-b", str(batch_id)]) + config_branch_entry(["config_advanced.yml", "new_restart_config.yml", "-b", "1"]) - get_opt_output_dir_mock.assert_called_once() - assert exists(new_config_file_name) + assert exists("new_restart_config.yml") opt_controls = {} - snapshot = SebaSnapshot(CACHED_SEBA_FOLDER) + snapshot = SebaSnapshot(Path(path) / "everest_output" / "optimization_output") for opt_data in snapshot._optimization_data(): - if opt_data.batch_id == batch_id: + if opt_data.batch_id == 1: opt_controls = opt_data.controls opt_control_val_for_batch_id = {v for k, v in opt_controls.items()} diff_lines = [] with ( - open(CONFIG_FILE, "r", encoding="utf-8") as initial_config, - open(new_config_file_name, "r", encoding="utf-8") as branch_config, + open("config_advanced.yml", "r", encoding="utf-8") as initial_config, + open("new_restart_config.yml", "r", encoding="utf-8") as branch_config, ): diff = difflib.unified_diff( initial_config.readlines(),