diff --git a/src/ert/callbacks.py b/src/ert/callbacks.py index 8c14abcce5c..8189dbfab3e 100644 --- a/src/ert/callbacks.py +++ b/src/ert/callbacks.py @@ -6,7 +6,6 @@ from pathlib import Path from ert.config import InvalidResponseFile -from ert.run_arg import RunArg from ert.storage import Ensemble from ert.storage.realization_storage_state import RealizationStorageState @@ -92,63 +91,23 @@ async def _write_responses_to_storage( async def forward_model_ok( - run_arg: RunArg, + run_path: str, + realization: int, + iter: int, + ensemble: Ensemble, ) -> LoadResult: parameters_result = LoadResult(LoadStatus.LOAD_SUCCESSFUL, "") response_result = LoadResult(LoadStatus.LOAD_SUCCESSFUL, "") try: # We only read parameters after the prior, after that, ERT # handles parameters - if run_arg.itr == 0: + if iter == 0: parameters_result = await _read_parameters( - run_arg.runpath, - run_arg.iens, - run_arg.ensemble_storage, - ) - - if parameters_result.status == LoadStatus.LOAD_SUCCESSFUL: - response_result = await _write_responses_to_storage( - run_arg.runpath, - run_arg.iens, - run_arg.ensemble_storage, + run_path, + realization, + ensemble, ) - except Exception as err: - logger.exception( - f"Failed to load results for realization {run_arg.iens}", - exc_info=err, - ) - parameters_result = LoadResult( - LoadStatus.LOAD_FAILURE, - "Failed to load results for realization " - f"{run_arg.iens}, failed with: {err}", - ) - - final_result = parameters_result - if response_result.status != LoadStatus.LOAD_SUCCESSFUL: - final_result = response_result - run_arg.ensemble_storage.set_failure( - run_arg.iens, RealizationStorageState.LOAD_FAILURE, final_result.message - ) - elif run_arg.ensemble_storage.has_failure(run_arg.iens): - run_arg.ensemble_storage.unset_failure(run_arg.iens) - - return final_result - - -async def load_run_path_realization( - run_path: str, - realization: int, - ensemble: Ensemble, -) -> LoadResult: - response_result = LoadResult(LoadStatus.LOAD_SUCCESSFUL, "") - try: - parameters_result = await _read_parameters( - run_path, - realization, - ensemble, - ) - if parameters_result.status == LoadStatus.LOAD_SUCCESSFUL: response_result = await _write_responses_to_storage( run_path, diff --git a/src/ert/libres_facade.py b/src/ert/libres_facade.py index 551ff5187ea..bc516f04455 100644 --- a/src/ert/libres_facade.py +++ b/src/ert/libres_facade.py @@ -21,7 +21,7 @@ from pandas import DataFrame from ert.analysis import AnalysisEvent, SmootherSnapshot, smoother_update -from ert.callbacks import load_run_path_realization +from ert.callbacks import forward_model_ok from ert.config import ( EnkfObservationImplementationType, ErtConfig, @@ -50,7 +50,7 @@ def _load_realization_from_run_path( realization: int, ensemble: Ensemble, ) -> Tuple[LoadResult, int]: - result = asyncio.run(load_run_path_realization(run_path, realization, ensemble)) + result = asyncio.run(forward_model_ok(run_path, realization, 0, ensemble)) return result, realization diff --git a/src/ert/scheduler/job.py b/src/ert/scheduler/job.py index a3055f83d1a..cba6e9144ca 100644 --- a/src/ert/scheduler/job.py +++ b/src/ert/scheduler/job.py @@ -240,7 +240,12 @@ async def _verify_checksum( logger.error(f"Disk synchronization failed for {file_path}") async def _handle_finished_forward_model(self) -> None: - callback_status, status_msg = await forward_model_ok(self.real.run_arg) + callback_status, status_msg = await forward_model_ok( + run_path=self.real.run_arg.runpath, + realization=self.real.run_arg.iens, + iter=self.real.run_arg.itr, + ensemble=self.real.run_arg.ensemble_storage, + ) if self._message: self._message = status_msg else: diff --git a/tests/ert/unit_tests/scheduler/test_job.py b/tests/ert/unit_tests/scheduler/test_job.py index 775e68a0317..ef3e0307394 100644 --- a/tests/ert/unit_tests/scheduler/test_job.py +++ b/tests/ert/unit_tests/scheduler/test_job.py @@ -119,7 +119,7 @@ async def test_job_run_sends_expected_events( realization: Realization, monkeypatch, ): - async def load_result(_): + async def load_result(**_): return (forward_model_ok_result, "") monkeypatch.setattr(ert.scheduler.job, "forward_model_ok", load_result) diff --git a/tests/ert/unit_tests/test_load_forward_model.py b/tests/ert/unit_tests/test_load_forward_model.py index a08322057dd..ba6a5eaa599 100644 --- a/tests/ert/unit_tests/test_load_forward_model.py +++ b/tests/ert/unit_tests/test_load_forward_model.py @@ -11,7 +11,6 @@ from ert.config import ErtConfig from ert.enkf_main import create_run_path from ert.libres_facade import LibresFacade -from ert.run_arg import create_run_arguments from ert.storage import open_storage @@ -290,9 +289,9 @@ def test_that_the_states_are_set_correctly(): assert new_ensemble.has_data() -@pytest.mark.parametrize("iter", [None, 0, 1, 2, 3]) +@pytest.mark.parametrize("itr", [None, 0, 1, 2, 3]) @pytest.mark.usefixtures("use_tmpdir") -def test_loading_from_any_available_iter(storage, run_paths, run_args, iter): +def test_loading_from_any_available_iter(storage, run_paths, run_args, itr): config_text = dedent( """ NUM_REALIZATIONS 1 @@ -308,23 +307,21 @@ def test_loading_from_any_available_iter(storage, run_paths, run_args, iter): ), name="prior", ensemble_size=ert_config.model_config.num_realizations, - iteration=iter if iter is not None else 0, + iteration=itr if itr is not None else 0, ) - run_args = create_run_arguments( - run_paths(ert_config), - [True] * ert_config.model_config.num_realizations, - prior_ensemble, - ) create_run_path( - run_args, - prior_ensemble, - ert_config, - run_paths(ert_config), - ) - run_path = Path( - f"simulations/realization-0/iter-{iter if iter is not None else 0}/" + run_args=run_args(ert_config, prior_ensemble), + ensemble=prior_ensemble, + user_config_file=ert_config.user_config_file, + env_vars=ert_config.env_vars, + forward_model_steps=ert_config.forward_model_steps, + substitutions=ert_config.substitutions, + templates=ert_config.ert_templates, + model_config=ert_config.model_config, + runpaths=run_paths(ert_config), ) + run_path = Path(f"simulations/realization-0/iter-{itr if itr is not None else 0}/") with open(run_path / "response.out", "w", encoding="utf-8") as fout: fout.write("\n".join(["1", "2", "3"])) with open(run_path / "response.out_active", "w", encoding="utf-8") as fout: @@ -333,7 +330,7 @@ def test_loading_from_any_available_iter(storage, run_paths, run_args, iter): facade = LibresFacade.from_config_file("config.ert") run_path_format = str( Path( - f"simulations/realization-/iter-{iter if iter is not None else 0}" + f"simulations/realization-/iter-{itr if itr is not None else 0}" ).resolve() ) facade.load_from_run_path(run_path_format, prior_ensemble, [0])