From 313caf987599c11c054c5ac79306dd259d66e132 Mon Sep 17 00:00:00 2001 From: DanSava Date: Mon, 16 Sep 2024 09:27:57 +0300 Subject: [PATCH] Enable load results manually from any available iteration --- src/ert/callbacks.py | 77 +++++++++--- src/ert/config/gen_data_config.py | 10 +- src/ert/gui/ertwidgets/__init__.py | 2 + src/ert/gui/ertwidgets/textbox.py | 110 ++++++++++++++++++ .../tools/load_results/load_results_panel.py | 73 ++++-------- src/ert/libres_facade.py | 68 +++++------ src/ert/validation/__init__.py | 2 + src/ert/validation/string_definition.py | 39 +++++++ src/everest/bin/everload_script.py | 2 +- .../performance_tests/enkf/test_load_state.py | 4 +- .../cli/test_parameter_sample_types.py | 2 +- .../gui/test_load_results_manually.py | 14 ++- .../scenarios/test_summary_response.py | 4 +- .../ert/unit_tests/storage/create_runpath.py | 2 +- .../ert/unit_tests/test_load_forward_model.py | 67 +++++++++-- tests/ert/unit_tests/test_summary_response.py | 2 +- 16 files changed, 346 insertions(+), 132 deletions(-) create mode 100644 src/ert/gui/ertwidgets/textbox.py create mode 100644 src/ert/validation/string_definition.py diff --git a/src/ert/callbacks.py b/src/ert/callbacks.py index 32b9f639c6d..5d2196c696f 100644 --- a/src/ert/callbacks.py +++ b/src/ert/callbacks.py @@ -4,10 +4,9 @@ import logging import time from pathlib import Path -from typing import Iterable -from ert.config import ParameterConfig, ResponseConfig from ert.run_arg import RunArg +from ert.storage import Ensemble from ert.storage.realization_storage_state import RealizationStorageState from .load_status import LoadResult, LoadStatus @@ -16,24 +15,27 @@ async def _read_parameters( - run_arg: RunArg, parameter_configuration: Iterable[ParameterConfig] + run_path: str, + realization: int, + ensemble: Ensemble, ) -> LoadResult: result = LoadResult(LoadStatus.LOAD_SUCCESSFUL, "") error_msg = "" + parameter_configuration = ensemble.experiment.parameter_configuration.values() for config in parameter_configuration: if not config.forward_init: continue try: start_time = time.perf_counter() logger.debug(f"Starting to load parameter: {config.name}") - ds = config.read_from_runpath(Path(run_arg.runpath), run_arg.iens) + ds = config.read_from_runpath(Path(run_path), realization) await asyncio.sleep(0) logger.debug( f"Loaded {config.name}", extra={"Time": f"{(time.perf_counter() - start_time):.4f}s"}, ) start_time = time.perf_counter() - run_arg.ensemble_storage.save_parameters(config.name, run_arg.iens, ds) + ensemble.save_parameters(config.name, realization, ds) await asyncio.sleep(0) logger.debug( f"Saved {config.name} to storage", @@ -42,28 +44,29 @@ async def _read_parameters( except Exception as err: error_msg += str(err) result = LoadResult(LoadStatus.LOAD_FAILURE, error_msg) - logger.warning(f"Failed to load: {run_arg.iens}", exc_info=err) + logger.warning(f"Failed to load: {realization}", exc_info=err) return result async def _write_responses_to_storage( - run_arg: RunArg, response_configs: Iterable[ResponseConfig] + run_path: str, + realization: int, + ensemble: Ensemble, ) -> LoadResult: errors = [] + response_configs = ensemble.experiment.response_configuration.values() for config in response_configs: try: start_time = time.perf_counter() logger.debug(f"Starting to load response: {config.response_type}") - ds = config.read_from_file(run_arg.runpath, run_arg.iens) + ds = config.read_from_file(run_path, realization) await asyncio.sleep(0) logger.debug( f"Loaded {config.response_type}", extra={"Time": f"{(time.perf_counter() - start_time):.4f}s"}, ) start_time = time.perf_counter() - run_arg.ensemble_storage.save_response( - config.response_type, ds, run_arg.iens - ) + ensemble.save_response(config.response_type, ds, realization) await asyncio.sleep(0) logger.debug( f"Saved {config.response_type} to storage", @@ -71,7 +74,7 @@ async def _write_responses_to_storage( ) except ValueError as err: errors.append(str(err)) - logger.warning(f"Failed to write: {run_arg.iens}", exc_info=err) + logger.warning(f"Failed to write: {realization}", exc_info=err) if errors: return LoadResult(LoadStatus.LOAD_FAILURE, "\n".join(errors)) return LoadResult(LoadStatus.LOAD_SUCCESSFUL, "") @@ -87,14 +90,16 @@ async def forward_model_ok( # handles parameters if run_arg.itr == 0: parameters_result = await _read_parameters( - run_arg, - run_arg.ensemble_storage.experiment.parameter_configuration.values(), + run_arg.runpath, + run_arg.iens, + run_arg.ensemble_storage, ) if parameters_result.status == LoadStatus.LOAD_SUCCESSFUL: response_result = await _write_responses_to_storage( - run_arg, - run_arg.ensemble_storage.experiment.response_configuration.values(), + run_arg.runpath, + run_arg.iens, + run_arg.ensemble_storage, ) except Exception as err: @@ -115,3 +120,43 @@ async def forward_model_ok( run_arg.ensemble_storage.unset_failure(run_arg.iens) return final_result + + +async def load_realization( + run_path: str, + realization: int, + ensemble: Ensemble, +) -> LoadResult: + response_result = LoadResult(LoadStatus.LOAD_SUCCESSFUL, "") + try: + parameters_result = await _read_parameters( + run_path, + realization, + ensemble, + ) + + if parameters_result.status == LoadStatus.LOAD_SUCCESSFUL: + response_result = await _write_responses_to_storage( + run_path, + realization, + ensemble, + ) + + except Exception as err: + logger.exception(f"Failed to load results for realization {realization}") + parameters_result = LoadResult( + LoadStatus.LOAD_FAILURE, + "Failed to load results for realization " + f"{realization}, failed with: {err}", + ) + + final_result = parameters_result + if response_result.status != LoadStatus.LOAD_SUCCESSFUL: + final_result = response_result + ensemble.set_failure( + realization, RealizationStorageState.LOAD_FAILURE, final_result.message + ) + elif ensemble.has_failure(realization): + ensemble.unset_failure(realization) + + return final_result diff --git a/src/ert/config/gen_data_config.py b/src/ert/config/gen_data_config.py index 9c9e4093fd1..a24315a4c5e 100644 --- a/src/ert/config/gen_data_config.py +++ b/src/ert/config/gen_data_config.py @@ -108,11 +108,11 @@ def from_config_dict(cls, config_dict: ConfigDict) -> Optional[Self]: ) def read_from_file(self, run_path: str, _: int) -> polars.DataFrame: - def _read_file(filename: Path, report_step: int) -> polars.DataFrame: - if not filename.exists(): - raise ValueError(f"Missing output file: {filename}") - data = np.loadtxt(_run_path / filename, ndmin=1) - active_information_file = _run_path / (str(filename) + "_active") + def _read_file(file_path: Path, report_step: int) -> polars.DataFrame: + if not file_path.exists(): + raise ValueError(f"Missing output file: {file_path}") + data = np.loadtxt(file_path, ndmin=1) + active_information_file = _run_path / (file_path.name + "_active") if active_information_file.exists(): active_list = np.loadtxt(active_information_file) data[active_list == 0] = np.nan diff --git a/src/ert/gui/ertwidgets/__init__.py b/src/ert/gui/ertwidgets/__init__.py index e9d3cda9542..00253078b04 100644 --- a/src/ert/gui/ertwidgets/__init__.py +++ b/src/ert/gui/ertwidgets/__init__.py @@ -25,6 +25,7 @@ def wrapper(*arg: Any) -> Any: from .ensembleselector import EnsembleSelector from .checklist import CheckList from .stringbox import StringBox +from .textbox import TextBox from .listeditbox import ListEditBox from .customdialog import CustomDialog from .pathchooser import PathChooser @@ -57,6 +58,7 @@ def wrapper(*arg: Any) -> Any: "SelectableListModel", "StringBox", "TargetEnsembleModel", + "TextBox", "TextModel", "ValueModel", "showWaitCursorWhileWaiting", diff --git a/src/ert/gui/ertwidgets/textbox.py b/src/ert/gui/ertwidgets/textbox.py new file mode 100644 index 00000000000..190be153502 --- /dev/null +++ b/src/ert/gui/ertwidgets/textbox.py @@ -0,0 +1,110 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Optional + +from qtpy.QtGui import QPalette +from qtpy.QtWidgets import QTextEdit + +from .validationsupport import ValidationSupport + +if TYPE_CHECKING: + from ert.validation import StringDefinition + + from .models import TextModel + + +class TextBox(QTextEdit): + """StringBox shows a string. The data structure expected and sent to the + getter and setter is a string.""" + + def __init__( + self, + model: TextModel, + default_string: str = "", + placeholder_text: str = "", + minimum_width: int = 250, + ): + """ + :type model: ert.gui.ertwidgets.models.valuemodel.ValueModel + :type help_link: str + :type default_string: str + """ + QTextEdit.__init__(self) + self.setMinimumWidth(minimum_width) + self._validation = ValidationSupport(self) + self._validator: Optional[StringDefinition] = None + self._model = model + self._enable_validation = True + + if placeholder_text: + self.setPlaceholderText(placeholder_text) + + self.textChanged.connect(self.textBoxChanged) + self.textChanged.connect(self.validateString) + + self._valid_color = self.palette().color(self.backgroundRole()) + self.setText(default_string) + + self._model.valueChanged.connect(self.modelChanged) + self.modelChanged() + + def validateString(self) -> None: + if self._enable_validation: + string_to_validate = self.get_text + if self._validator is not None: + status = self._validator.validate(string_to_validate) + + palette = QPalette() + if not status: + palette.setColor( + self.backgroundRole(), ValidationSupport.ERROR_COLOR + ) + self.setPalette(palette) + self._validation.setValidationMessage( + str(status), ValidationSupport.EXCLAMATION + ) + else: + palette.setColor(self.backgroundRole(), self._valid_color) + self.setPalette(palette) + self._validation.setValidationMessage("") + + def emitChange(self, q_string: Any) -> None: + self.textChanged.emit(str(q_string)) + + def textBoxChanged(self) -> None: + """Called whenever the contents of the textbox changes.""" + text: Optional[str] = self.toPlainText() + if not text: + text = None + + self._model.setValue(text) + + def modelChanged(self) -> None: + """Retrieves data from the model and inserts it into the textbox""" + text = self._model.getValue() + if text is None: + text = "" + # If model and view has same text, return + if text == self.toPlainText(): + return + self.setText(str(text)) + + @property + def model(self) -> TextModel: + return self._model + + def setValidator(self, validator: StringDefinition) -> None: + self._validator = validator + + def getValidationSupport(self) -> ValidationSupport: + return self._validation + + def isValid(self) -> bool: + return self._validation.isValid() + + @property + def get_text(self) -> str: + return self.toPlainText() if self.toPlainText() else self.placeholderText() + + def enable_validation(self, enabled: bool) -> None: + self._enable_validation = enabled diff --git a/src/ert/gui/tools/load_results/load_results_panel.py b/src/ert/gui/tools/load_results/load_results_panel.py index 04e6313acd2..0228ad0b03d 100644 --- a/src/ert/gui/tools/load_results/load_results_panel.py +++ b/src/ert/gui/tools/load_results/load_results_panel.py @@ -1,7 +1,7 @@ from __future__ import annotations from qtpy.QtCore import Qt, Signal -from qtpy.QtWidgets import QFormLayout, QMessageBox, QTextEdit, QWidget +from qtpy.QtWidgets import QFormLayout, QMessageBox, QWidget from ert.gui.ertnotifier import ErtNotifier from ert.gui.ertwidgets import ( @@ -10,11 +10,12 @@ ErtMessageBox, QApplication, StringBox, - ValueModel, + TextBox, + TextModel, ) from ert.libres_facade import LibresFacade from ert.run_models.base_run_model import captured_logs -from ert.validation import IntegerArgument, RangeStringArgument +from ert.validation import RangeStringArgument, StringDefinition class LoadResultsPanel(QWidget): @@ -34,46 +35,33 @@ def __init__(self, facade: LibresFacade, notifier: ErtNotifier): layout = QFormLayout() - run_path_text = QTextEdit() - run_path_text.setText(self.readCurrentRunPath()) - run_path_text.setDisabled(True) - run_path_text.setFixedHeight(80) + self._run_path_text = TextBox(TextModel(self.readCurrentRunPath())) + self._run_path_text.setFixedHeight(80) + self._run_path_text.setValidator(StringDefinition(required=[""])) + self._run_path_text.setObjectName("run_path_edit_lrm") + self._run_path_text.getValidationSupport().validationChanged.connect( + self.panelConfigurationChanged + ) - layout.addRow("Load data from current run path: ", run_path_text) + layout.addRow("Load data from run path: ", self._run_path_text) ensemble_selector = EnsembleSelector(self._notifier) layout.addRow("Load into ensemble:", ensemble_selector) self._ensemble_selector = ensemble_selector - self._active_realizations_model = ActiveRealizationsModel( - self._facade.get_ensemble_size() - ) + ensemble_size = self._facade.get_ensemble_size() + self._active_realizations_model = ActiveRealizationsModel(ensemble_size) self._active_realizations_field = StringBox( self._active_realizations_model, # type: ignore "load_results_manually/Realizations", ) - self._active_realizations_field.setValidator( - RangeStringArgument(self._facade.get_ensemble_size()), - ) + self._active_realizations_field.setValidator(RangeStringArgument(ensemble_size)) self._active_realizations_field.setObjectName("active_realizations_lrm") layout.addRow("Realizations to load:", self._active_realizations_field) - self._iterations_model = ValueModel(0) # type: ignore - self._iterations_field = StringBox( - self._iterations_model, # type: ignore - "load_results_manually/iterations", - ) - self._iterations_field.setValidator(IntegerArgument(from_value=0)) - self._iterations_field.setObjectName("iterations_field_lrm") - layout.addRow("Iteration to load:", self._iterations_field) - self._active_realizations_field.getValidationSupport().validationChanged.connect( self.panelConfigurationChanged ) - self._iterations_field.getValidationSupport().validationChanged.connect( - self.panelConfigurationChanged - ) - self.setLayout(layout) def readCurrentRunPath(self) -> str: @@ -85,36 +73,21 @@ def readCurrentRunPath(self) -> str: def isConfigurationValid(self) -> bool: return ( - self._active_realizations_field.isValid() - and self._iterations_field.isValid() + self._active_realizations_field.isValid() and self._run_path_text.isValid() ) def load(self) -> int: - selected_ensemble = self._notifier.current_ensemble realizations = self._active_realizations_model.getActiveRealizationsMask() - iteration = self._iterations_model.getValue() - try: - if iteration is None: - iteration = "" - iteration_int = int(iteration) - except ValueError: - QMessageBox.warning( - self, - "Warning", - ( - "Expected an integer number in iteration field, " - f'got "{iteration}"' - ), - ) - return False - + active_realizations = [ + iens for iens, active in enumerate(realizations) if active + ] QApplication.setOverrideCursor(Qt.CursorShape.WaitCursor) messages: list[str] = [] with captured_logs(messages): - loaded = self._facade.load_from_forward_model( - selected_ensemble, # type: ignore - realizations, # type: ignore - iteration_int, + loaded = self._facade.load_from_run_path( + run_path_format=self._run_path_text.get_text, + ensemble=self._notifier.current_ensemble, # type: ignore + active_realizations=active_realizations, ) QApplication.restoreOverrideCursor() diff --git a/src/ert/libres_facade.py b/src/ert/libres_facade.py index a3be07222e3..40f9e4577ed 100644 --- a/src/ert/libres_facade.py +++ b/src/ert/libres_facade.py @@ -5,6 +5,7 @@ import time import warnings from multiprocessing.pool import ThreadPool +from pathlib import Path from typing import ( TYPE_CHECKING, Any, @@ -20,7 +21,7 @@ from pandas import DataFrame from ert.analysis import AnalysisEvent, SmootherSnapshot, smoother_update -from ert.callbacks import forward_model_ok +from ert.callbacks import load_realization from ert.config import ( EnkfObservationImplementationType, ErtConfig, @@ -29,10 +30,8 @@ from ert.data import MeasuredData from ert.data._measured_data import ObservationError, ResponseError from ert.load_status import LoadResult, LoadStatus -from ert.run_arg import create_run_arguments from .plugins import ErtPluginContext -from .runpaths import Runpaths _logger = logging.getLogger(__name__) @@ -43,16 +42,16 @@ EnkfObs, WorkflowJob, ) - from ert.run_arg import RunArg from ert.storage import Ensemble, Storage -def _load_realization( - realisation: int, - run_args: List[RunArg], +def _load_realization_from_run_path( + run_path: str, + realization: int, + ensemble: Ensemble, ) -> Tuple[LoadResult, int]: - result = asyncio.run(forward_model_ok(run_args[realisation])) - return result, realisation + result = asyncio.run(load_realization(run_path, realization, ensemble)) + return result, realization class LibresFacade: @@ -122,34 +121,20 @@ def get_ensemble_size(self) -> int: def run_path(self) -> str: return self.config.model_config.runpath_format_string + @property + def resolved_run_path(self) -> str: + return str(Path(self.config.model_config.runpath_format_string).resolve()) + def load_from_forward_model( self, ensemble: Ensemble, realisations: npt.NDArray[np.bool_], - iteration: Optional[int] = None, ) -> int: - if iteration is not None: - warnings.warn( - "The iteration argument has no effect, iteration is read from ensemble", - DeprecationWarning, - stacklevel=1, - ) t = time.perf_counter() - run_args = create_run_arguments( - Runpaths( - jobname_format=self.config.model_config.jobname_format_string, - runpath_format=self.config.model_config.runpath_format_string, - filename=str(self.config.runpath_file), - substitution_list=self.config.substitution_list, - eclbase=self.config.ensemble_config.eclbase, - ), - realisations, - ensemble=ensemble, - ) - nr_loaded = self._load_from_run_path( - self.config.model_config.num_realizations, - run_args, - realisations, + nr_loaded = self.load_from_run_path( + self.resolved_run_path, + ensemble, + [r for r, active in enumerate(realisations) if active], ) _logger.debug( f"load_from_forward_model() time_used {(time.perf_counter() - t):.4f}s" @@ -157,21 +142,26 @@ def load_from_forward_model( return nr_loaded @staticmethod - def _load_from_run_path( - ensemble_size: int, - run_args: List[RunArg], - active_realizations: npt.NDArray[np.bool_], + def load_from_run_path( + run_path_format: str, + ensemble: Ensemble, + active_realizations: List[int], ) -> int: """Returns the number of loaded realizations""" pool = ThreadPool(processes=8) async_result = [ pool.apply_async( - _load_realization, - (iens, run_args), + _load_realization_from_run_path, + ( + run_path_format.replace("", str(realization)).replace( + "", "0" + ), + realization, + ensemble, + ), ) - for iens in range(ensemble_size) - if active_realizations[iens] + for realization in active_realizations ] loaded = 0 diff --git a/src/ert/validation/__init__.py b/src/ert/validation/__init__.py index c4f5e1a2990..32c13f2c82d 100644 --- a/src/ert/validation/__init__.py +++ b/src/ert/validation/__init__.py @@ -6,6 +6,7 @@ from .proper_name_format_argument import ProperNameFormatArgument from .range_string_argument import RangeStringArgument from .rangestring import mask_to_rangestring, rangestring_to_list, rangestring_to_mask +from .string_definition import StringDefinition from .validation_status import ValidationStatus __all__ = [ @@ -17,6 +18,7 @@ "ProperNameArgument", "ProperNameFormatArgument", "RangeStringArgument", + "StringDefinition", "ValidationStatus", "mask_to_rangestring", "rangestring_to_list", diff --git a/src/ert/validation/string_definition.py b/src/ert/validation/string_definition.py new file mode 100644 index 00000000000..9547fe7dd34 --- /dev/null +++ b/src/ert/validation/string_definition.py @@ -0,0 +1,39 @@ +from typing import List, Optional + +from .validation_status import ValidationStatus + + +class StringDefinition: + MISSING_TOKEN = "Missing required %s!" + INVALID_TOKEN = "Contains invalid string %s!" + + def __init__( + self, + optional: bool = False, + required: Optional[List[str]] = None, + invalid: Optional[List[str]] = None, + ) -> None: + super().__init__() + self.__optional = optional + self._required_tokens = required or [] + self._invalid_tokens = invalid or [] + + def isOptional(self) -> bool: + return self.__optional + + def validate(self, value: str) -> ValidationStatus: + vs = ValidationStatus() + required = [token for token in self._required_tokens if token not in value] + invalid = [token for token in self._invalid_tokens if token in value] + + if not self.isOptional() and any(required): + vs.setFailed() + for token in required: + vs.addToMessage(StringDefinition.MISSING_TOKEN % token) + + if not self.isOptional() and any(invalid): + vs.setFailed() + for token in invalid: + vs.addToMessage(StringDefinition.INVALID_TOKEN % token) + + return vs diff --git a/src/everest/bin/everload_script.py b/src/everest/bin/everload_script.py index d1552d93979..ab1d9bcdad2 100755 --- a/src/everest/bin/everload_script.py +++ b/src/everest/bin/everload_script.py @@ -189,7 +189,7 @@ def _internalize_batch(ert_config, batch_id, batch_data): realizations = [True] * batch_size + [False] * ( facade.get_ensemble_size() - batch_size ) - facade.load_from_forward_model(ensemble, realizations, 0) + facade.load_from_forward_model(ensemble, realizations) if __name__ == "__main__": diff --git a/tests/ert/performance_tests/enkf/test_load_state.py b/tests/ert/performance_tests/enkf/test_load_state.py index e5c481841dc..51442345d5e 100644 --- a/tests/ert/performance_tests/enkf/test_load_state.py +++ b/tests/ert/performance_tests/enkf/test_load_state.py @@ -15,7 +15,7 @@ def test_load_from_context(benchmark, template_config): expected_reals = template_config["reals"] realisations = [True] * expected_reals loaded_reals = benchmark( - facade.load_from_forward_model, load_into, realisations, 0 + facade.load_from_forward_model, load_into, realisations ) assert loaded_reals == expected_reals @@ -30,6 +30,6 @@ def test_load_from_fs(benchmark, template_config): expected_reals = template_config["reals"] realisations = [True] * expected_reals loaded_reals = benchmark( - facade.load_from_forward_model, load_from, realisations, 0 + facade.load_from_forward_model, load_from, realisations ) assert loaded_reals == expected_reals diff --git a/tests/ert/ui_tests/cli/test_parameter_sample_types.py b/tests/ert/ui_tests/cli/test_parameter_sample_types.py index 4ffef35f05f..2e921e421d6 100644 --- a/tests/ert/ui_tests/cli/test_parameter_sample_types.py +++ b/tests/ert/ui_tests/cli/test_parameter_sample_types.py @@ -18,7 +18,7 @@ def load_from_forward_model(ert_config, ensemble): facade = LibresFacade.from_config_file(ert_config) realizations = [True] * facade.get_ensemble_size() - return facade.load_from_forward_model(ensemble, realizations, 0) + return facade.load_from_forward_model(ensemble, realizations) @pytest.mark.usefixtures("set_site_config") diff --git a/tests/ert/ui_tests/gui/test_load_results_manually.py b/tests/ert/ui_tests/gui/test_load_results_manually.py index 6dcc06cd688..9ddad1c3f42 100644 --- a/tests/ert/ui_tests/gui/test_load_results_manually.py +++ b/tests/ert/ui_tests/gui/test_load_results_manually.py @@ -1,7 +1,7 @@ from qtpy.QtCore import Qt, QTimer from qtpy.QtWidgets import QPushButton -from ert.gui.ertwidgets import ClosableDialog, StringBox +from ert.gui.ertwidgets import ClosableDialog, StringBox, TextBox from ert.gui.ertwidgets.ensembleselector import EnsembleSelector from ert.gui.tools.load_results import LoadResultsPanel @@ -25,6 +25,11 @@ def handle_load_results_dialog(): load_button = get_child(panel.parent(), QPushButton, name="Load") + run_path_edit = get_child(panel, TextBox, name="run_path_edit_lrm") + assert run_path_edit.isEnabled() + valid_text = run_path_edit.get_text + assert "" in valid_text + active_realizations = get_child( panel, StringBox, name="active_realizations_lrm" ) @@ -37,12 +42,9 @@ def handle_load_results_dialog(): active_realizations.setText(default_value_active_reals) assert load_button.isEnabled() - iterations_field = get_child(panel, StringBox, name="iterations_field_lrm") - default_value_iteration = iterations_field.get_text - iterations_field.setText("-10") - + run_path_edit.setText(valid_text.replace("", "")) assert not load_button.isEnabled() - iterations_field.setText(default_value_iteration) + run_path_edit.setText(valid_text) assert load_button.isEnabled() dialog.close() diff --git a/tests/ert/unit_tests/scenarios/test_summary_response.py b/tests/ert/unit_tests/scenarios/test_summary_response.py index bb07dddaf74..0c93300e9bc 100644 --- a/tests/ert/unit_tests/scenarios/test_summary_response.py +++ b/tests/ert/unit_tests/scenarios/test_summary_response.py @@ -76,9 +76,7 @@ def create_responses(config_file, prior_ensemble, response_times): run_sim(response_time, rng.standard_normal(), fname=f"ECLIPSE_CASE_{i}") os.chdir(cwd) facade = LibresFacade.from_config_file(config_file) - facade.load_from_forward_model( - prior_ensemble, [True] * facade.get_ensemble_size(), 0 - ) + facade.load_from_forward_model(prior_ensemble, [True] * facade.get_ensemble_size()) def test_that_reading_matching_time_is_ok(ert_config, storage, prior_ensemble): diff --git a/tests/ert/unit_tests/storage/create_runpath.py b/tests/ert/unit_tests/storage/create_runpath.py index 860d377097f..9c29368712f 100644 --- a/tests/ert/unit_tests/storage/create_runpath.py +++ b/tests/ert/unit_tests/storage/create_runpath.py @@ -56,4 +56,4 @@ def create_runpath( def load_from_forward_model(ert_config, ensemble): facade = LibresFacade.from_config_file(ert_config) realizations = [True] * facade.get_ensemble_size() - return facade.load_from_forward_model(ensemble, realizations, 0) + return facade.load_from_forward_model(ensemble, realizations) diff --git a/tests/ert/unit_tests/test_load_forward_model.py b/tests/ert/unit_tests/test_load_forward_model.py index ada18db7e77..55575c8a56d 100644 --- a/tests/ert/unit_tests/test_load_forward_model.py +++ b/tests/ert/unit_tests/test_load_forward_model.py @@ -11,6 +11,7 @@ from ert.config import ErtConfig from ert.enkf_main import create_run_path from ert.libres_facade import LibresFacade +from ert.run_arg import create_run_arguments from ert.storage import open_storage @@ -77,7 +78,7 @@ def test_load_forward_model(snake_oil_default_storage): experiment = storage.get_experiment_by_name("ensemble-experiment") default = experiment.get_ensemble_by_name("default_0") - loaded = facade.load_from_forward_model(default, realizations, 0) + loaded = facade.load_from_forward_model(default, realizations) assert loaded == 1 assert default.get_realization_mask_with_responses()[ realisation_number @@ -141,7 +142,7 @@ def test_load_forward_model_summary( ) facade = LibresFacade(ert_config) with caplog.at_level(logging.ERROR): - loaded = facade.load_from_forward_model(prior_ensemble, [True], 0) + loaded = facade.load_from_forward_model(prior_ensemble, [True]) expected_loaded, expected_log_message = expected assert loaded == expected_loaded if expected_log_message: @@ -166,7 +167,7 @@ def test_load_forward_model_gen_data(setup_case): fout.write("\n".join(["1", "0", "1"])) facade = LibresFacade(config) - facade.load_from_forward_model(prior_ensemble, [True], 0) + facade.load_from_forward_model(prior_ensemble, [True]) df = prior_ensemble.load_responses("gen_data", (0,)) filter_cond = polars.col("report_step").eq(0), polars.col("values").is_not_nan() assert df.filter(filter_cond)["values"].to_list() == [1.0, 3.0] @@ -188,7 +189,7 @@ def test_single_valued_gen_data_with_active_info_is_loaded(setup_case): fout.write("\n".join(["1"])) facade = LibresFacade(config) - facade.load_from_forward_model(prior_ensemble, [True], 0) + facade.load_from_forward_model(prior_ensemble, [True]) df = prior_ensemble.load_responses("RESPONSE", (0,)) assert df["values"].to_list() == [1.0] @@ -209,7 +210,7 @@ def test_that_all_deactivated_values_are_loaded(setup_case): fout.write("\n".join(["0"])) facade = LibresFacade(config) - facade.load_from_forward_model(prior_ensemble, [True], 0) + facade.load_from_forward_model(prior_ensemble, [True]) response = prior_ensemble.load_responses("RESPONSE", (0,)) assert np.isnan(response[0]["values"].to_list()) assert len(response) == 1 @@ -247,7 +248,7 @@ def test_loading_gen_data_without_restart(storage, run_paths, run_args): fout.write("\n".join(["1", "0", "1"])) facade = LibresFacade.from_config_file("config.ert") - facade.load_from_forward_model(prior_ensemble, [True], 0) + facade.load_from_forward_model(prior_ensemble, [True]) df = prior_ensemble.load_responses("RESPONSE", (0,)) df_no_nans = df.filter(polars.col("values").is_not_nan()) assert df_no_nans["values"].to_list() == [1.0, 3.0] @@ -269,6 +270,58 @@ def test_that_the_states_are_set_correctly(): new_ensemble = storage.create_ensemble( experiment=ensemble.experiment, ensemble_size=ensemble_size ) - facade.load_from_forward_model(new_ensemble, realizations, 0) + facade.load_from_forward_model(new_ensemble, realizations) assert not new_ensemble.is_initalized() assert new_ensemble.has_data() + + +@pytest.mark.parametrize("iter", [None, 0, 1, 2, 3]) +@pytest.mark.usefixtures("use_tmpdir") +def test_loading_from_any_available_iter(storage, run_paths, run_args, iter): + config_text = dedent( + """ + NUM_REALIZATIONS 1 + GEN_DATA RESPONSE RESULT_FILE:response.out INPUT_FORMAT:ASCII + """ + ) + Path("config.ert").write_text(config_text, encoding="utf-8") + + ert_config = ErtConfig.from_file("config.ert") + prior_ensemble = storage.create_ensemble( + storage.create_experiment( + responses=ert_config.ensemble_config.response_configuration + ), + name="prior", + ensemble_size=ert_config.model_config.num_realizations, + iteration=iter if iter is not None else 0, + ) + + run_args = create_run_arguments( + run_paths(ert_config), + [True] * ert_config.model_config.num_realizations, + prior_ensemble, + ) + create_run_path( + run_args, + prior_ensemble, + ert_config, + run_paths(ert_config), + ) + run_path = Path( + f"simulations/realization-0/iter-{iter if iter is not None else 0}/" + ) + with open(run_path / "response.out", "w", encoding="utf-8") as fout: + fout.write("\n".join(["1", "2", "3"])) + with open(run_path / "response.out_active", "w", encoding="utf-8") as fout: + fout.write("\n".join(["1", "0", "1"])) + + facade = LibresFacade.from_config_file("config.ert") + run_path_format = str( + Path( + f"simulations/realization-/iter-{iter if iter is not None else 0}" + ).resolve() + ) + facade.load_from_run_path(run_path_format, prior_ensemble, [0]) + df = prior_ensemble.load_responses("RESPONSE", (0,)) + df_no_nans = df.filter(polars.col("values").is_not_nan()) + assert df_no_nans["values"].to_list() == [1.0, 3.0] diff --git a/tests/ert/unit_tests/test_summary_response.py b/tests/ert/unit_tests/test_summary_response.py index 8f57fa0d231..a870f1490ad 100644 --- a/tests/ert/unit_tests/test_summary_response.py +++ b/tests/ert/unit_tests/test_summary_response.py @@ -49,7 +49,7 @@ def test_load_summary_response_restart_not_zero( shutil.copy(test_path / "PRED_RUN.UNSMRY", sim_path / "PRED_RUN.UNSMRY") facade = LibresFacade.from_config_file("config.ert") - facade.load_from_forward_model(ensemble, [True], 0) + facade.load_from_forward_model(ensemble, [True]) df = ensemble.load_responses("summary", (0,)) df = df.pivot(on="response_key", values="values")