From 2a76405759d3ce50dd265bc68f2e39b1bd7c45da Mon Sep 17 00:00:00 2001
From: Øyvind Eide
Date: Wed, 19 Jun 2024 12:29:58 +0200
Subject: [PATCH 1/2] Fix bug where there were false positives in runpath check

---
 src/ert/gui/simulation/evaluate_ensemble_panel.py | 5 ++++-
 src/ert/run_models/base_run_model.py              | 2 +-
 src/ert/run_models/model_factory.py               | 5 +++++
 src/ert/run_models/run_arguments.py               | 1 +
 4 files changed, 11 insertions(+), 2 deletions(-)

diff --git a/src/ert/gui/simulation/evaluate_ensemble_panel.py b/src/ert/gui/simulation/evaluate_ensemble_panel.py
index a6df5133a82..e518ff5f7fb 100644
--- a/src/ert/gui/simulation/evaluate_ensemble_panel.py
+++ b/src/ert/gui/simulation/evaluate_ensemble_panel.py
@@ -20,6 +20,7 @@ class Arguments:
     mode: str
     realizations: str
     ensemble_id: str
+    start_iteration: int
 
 
 class EvaluateEnsemblePanel(ExperimentConfigPanel):
@@ -70,10 +71,12 @@ def isConfigurationValid(self) -> bool:
         )
 
     def get_experiment_arguments(self) -> Arguments:
+        ensemble_id = self._ensemble_selector.selected_ensemble.id
         return Arguments(
             mode=EVALUATE_ENSEMBLE_MODE,
-            ensemble_id=str(self._ensemble_selector.selected_ensemble.id),
+            ensemble_id=str(ensemble_id),
             realizations=self._active_realizations_field.text(),
+            start_iteration=self.notifier.storage.get_ensemble(ensemble_id).iteration,
         )
 
     def _realizations_from_fs(self) -> None:
diff --git a/src/ert/run_models/base_run_model.py b/src/ert/run_models/base_run_model.py
index 111a7ed1685..0e08aba1d94 100644
--- a/src/ert/run_models/base_run_model.py
+++ b/src/ert/run_models/base_run_model.py
@@ -607,7 +607,7 @@ def paths(self) -> List[str]:
         number_of_iterations = self._simulation_arguments.num_iterations
         active_mask = self._simulation_arguments.active_realizations
         active_realizations = [i for i in range(len(active_mask)) if active_mask[i]]
-        for iteration in range(start_iteration, number_of_iterations):
+        for iteration in range(start_iteration, start_iteration + number_of_iterations):
             run_paths.extend(self.run_paths.get_paths(active_realizations, iteration))
         return run_paths
 
diff --git a/src/ert/run_models/model_factory.py b/src/ert/run_models/model_factory.py
index 74f6c598ed1..13d839327c0 100644
--- a/src/ert/run_models/model_factory.py
+++ b/src/ert/run_models/model_factory.py
@@ -3,6 +3,7 @@
 import logging
 from queue import SimpleQueue
 from typing import TYPE_CHECKING, Tuple
+from uuid import UUID
 
 import numpy as np
 
@@ -165,6 +166,7 @@ def _setup_evaluate_ensemble(
             ensemble_size=config.model_config.num_realizations,
             stop_long_running=config.analysis_config.stop_long_running,
             experiment_name=None,
+            start_iteration=args.start_iteration,
         ),
         config,
         storage,
@@ -244,6 +246,9 @@ def _setup_multiple_data_assimilation(
             ensemble_size=config.model_config.num_realizations,
             stop_long_running=config.analysis_config.stop_long_running,
             experiment_name=args.experiment_name,
+            start_iteration=storage.get_ensemble(UUID(prior_ensemble)).iteration + 1
+            if restart_run
+            else 0,
         ),
         config,
         storage,
diff --git a/src/ert/run_models/run_arguments.py b/src/ert/run_models/run_arguments.py
index 3797addd42d..8dc4d00b298 100644
--- a/src/ert/run_models/run_arguments.py
+++ b/src/ert/run_models/run_arguments.py
@@ -42,6 +42,7 @@ def __post_init__(self) -> None:
 class EvaluateEnsembleRunArguments(SimulationArguments):
     active_realizations: List[bool]
     current_ensemble: str
+    start_iteration: int
     ensemble_type: str = "Evaluate ensemble"
 
     def __post_init__(self) -> None:
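
Editor's note (not part of the patch): a minimal sketch of the arithmetic behind the one-line change in base_run_model.py above. num_iterations is a count of iterations rather than an end index, so the upper bound of the loop has to be offset by start_iteration. The concrete numbers below are hypothetical; only the two range expressions are taken from the diff.

    # Hypothetical values; only the two range expressions mirror the patch.
    start_iteration = 1  # e.g. an evaluate-ensemble or restarted run starting after the prior
    num_iterations = 3   # how many iterations this run will produce

    old_window = list(range(start_iteration, num_iterations))                    # [1, 2]
    new_window = list(range(start_iteration, start_iteration + num_iterations))  # [1, 2, 3]

    assert old_window == [1, 2]      # wrong window: one iteration short
    assert new_window == [1, 2, 3]   # the iterations whose runpaths should be checked
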
From 495e02a9121ebda02c6678a9578abd8af0c83b3f Mon Sep 17 00:00:00 2001
From: Øyvind Eide
Date: Thu, 13 Jun 2024 10:05:07 +0200
Subject: [PATCH 2/2] Simplify weights

---
 src/ert/cli/main.py                              |  2 +-
 .../multiple_data_assimilation_panel.py          | 20 +++---
 .../run_models/multiple_data_assimilation.py     | 62 ++++++++-----------
 tests/integration_tests/test_cli.py              |  5 +-
 .../test_multiple_data_assimilation.py           | 36 +++++------
 5 files changed, 51 insertions(+), 74 deletions(-)

diff --git a/src/ert/cli/main.py b/src/ert/cli/main.py
index 55c949702dc..028b7296273 100644
--- a/src/ert/cli/main.py
+++ b/src/ert/cli/main.py
@@ -94,7 +94,7 @@ def run_cli(args: Namespace, plugin_manager: Optional[ErtPluginManager] = None)
             status_queue,
         )
     except ValueError as e:
-        raise ErtCliError(e) from e
+        raise ErtCliError(f"{args.mode} was not valid, failed with: {e}") from e
 
     if args.port_range is None and model.queue_system == QueueSystem.LOCAL:
         args.port_range = range(49152, 51819)
diff --git a/src/ert/gui/simulation/multiple_data_assimilation_panel.py b/src/ert/gui/simulation/multiple_data_assimilation_panel.py
index 40e39fca215..24e32f2c40f 100644
--- a/src/ert/gui/simulation/multiple_data_assimilation_panel.py
+++ b/src/ert/gui/simulation/multiple_data_assimilation_panel.py
@@ -158,18 +158,16 @@ def updateVisualizationOfNormalizedWeights() -> None:
             self.weights_valid = False
 
             if self._relative_iteration_weights_box.isValid():
-                weights = MultipleDataAssimilation.parseWeights(
-                    relative_iteration_weights_model.getValue()
-                )
-                normalized_weights = MultipleDataAssimilation.normalizeWeights(weights)
-                normalized_weights_model.setValue(
-                    ", ".join(f"{x:.2f}" for x in normalized_weights)
-                )
-
-                if not weights:
-                    normalized_weights_model.setValue("The weights are invalid!")
-                else:
+                try:
+                    normalized_weights = MultipleDataAssimilation.parse_weights(
+                        relative_iteration_weights_model.getValue()
+                    )
+                    normalized_weights_model.setValue(
+                        ", ".join(f"{x:.2f}" for x in normalized_weights)
+                    )
                     self.weights_valid = True
+                except ValueError:
+                    normalized_weights_model.setValue("The weights are invalid!")
             else:
                 normalized_weights_model.setValue("The weights are invalid!")
 
diff --git a/src/ert/run_models/multiple_data_assimilation.py b/src/ert/run_models/multiple_data_assimilation.py
index e237935ee77..70f8a9777ab 100644
--- a/src/ert/run_models/multiple_data_assimilation.py
+++ b/src/ert/run_models/multiple_data_assimilation.py
@@ -45,6 +45,9 @@ def __init__(
         update_settings: UpdateSettings,
         status_queue: SimpleQueue[StatusEvents],
     ):
+        self.weights = self.parse_weights(simulation_arguments.weights)
+        self.es_settings = es_settings
+        self.update_settings = update_settings
         super().__init__(
             simulation_arguments,
             config,
@@ -53,9 +56,13 @@ def __init__(
             status_queue,
             phase_count=2,
         )
-        self.weights = MultipleDataAssimilation.default_weights
-        self.es_settings = es_settings
-        self.update_settings = update_settings
+        if simulation_arguments.start_iteration == 0:
+            # If a regular run we also need to account for the prior
+            self.simulation_arguments.num_iterations = len(self.weights) + 1
+        else:
+            self.simulation_arguments.num_iterations = len(
+                self.weights[simulation_arguments.start_iteration - 1 :]
+            )
 
     def run_experiment(
         self, evaluator_server_config: EvaluatorServerConfig
@@ -64,21 +71,10 @@ def run_experiment(
             self._simulation_arguments.active_realizations.count(True),
             self._simulation_arguments.minimum_required_realizations,
         )
-        weights = self.parseWeights(self._simulation_arguments.weights)
+        iteration_count = self.simulation_arguments.num_iterations
+        self.setPhaseCount(iteration_count)
 
-        if not weights:
-            raise ErtRunError(
-                "Operation halted: ES-MDA requires weights to proceed. "
-                "Please provide appropriate weights and try again."
-            )
-
-        iteration_count = len(weights)
-
-        weights = self.normalizeWeights(weights)
-
-        self.setPhaseCount(iteration_count + 1)
-
-        log_msg = f"Running ES-MDA with normalized weights {weights}"
+        log_msg = f"Running ES-MDA with normalized weights {self.weights}"
         logger.info(log_msg)
         self.setPhaseName(log_msg)
 
@@ -137,7 +133,7 @@ def run_experiment(
             random_seed=self.random_seed,
         )
         self._evaluate_and_postprocess(prior_context, evaluator_server_config)
-        enumerated_weights = list(enumerate(weights))
+        enumerated_weights = list(enumerate(self.weights))
         weights_to_run = enumerated_weights[prior.iteration :]
 
         for iteration, weight in weights_to_run:
@@ -191,7 +187,7 @@ def run_experiment(
 
         self.setPhaseName("Post processing...")
 
-        self.setPhase(iteration_count + 1, "Experiment completed.")
+        self.setPhase(iteration_count, "Experiment completed.")
 
         return prior_context
 
@@ -228,28 +224,19 @@ def update(
             ) from e
 
     @staticmethod
-    def normalizeWeights(weights: List[float]) -> List[float]:
-        """Scale weights such that their reciprocals sum to 1.0,
-        i.e., sum(1.0 / x for x in weights) == 1.0.
-        See for example Equation 38 of evensen2018 - Analysis of iterative
-        ensemble smoothers for solving inverse problems.
+    def parse_weights(weights: str) -> List[float]:
+        """Parse weights string and scale weights such that their reciprocals sum
+        to 1.0, i.e., sum(1.0 / x for x in weights) == 1.0. See for example Equation
+        38 of evensen2018 - Analysis of iterative ensemble
+        smoothers for solving inverse problems.
         """
         if not weights:
-            return []
-        weights = [weight for weight in weights if abs(weight) != 0.0]
-
-        length = sum(1.0 / x for x in weights)
-        return [x * length for x in weights]
-
-    @staticmethod
-    def parseWeights(weights: str) -> List[float]:
-        if not weights:
-            return []
+            raise ValueError(f"Must provide weights, got {weights}")
 
         elements = weights.split(",")
         elements = [element.strip() for element in elements if element.strip()]
 
-        result = []
+        result: List[float] = []
         for element in elements:
             try:
                 f = float(element)
@@ -259,8 +246,11 @@ def parseWeights(weights: str) -> List[float]:
                     result.append(f)
             except ValueError as e:
                 raise ValueError(f"Warning: cannot parse weight {element}") from e
+        if not result:
+            raise ValueError(f"Invalid weights: {weights}")
 
-        return result
+        length = sum(1.0 / x for x in result)
+        return [x * length for x in result]
 
     @classmethod
     def name(cls) -> str:
diff --git a/tests/integration_tests/test_cli.py b/tests/integration_tests/test_cli.py
index 547cdcd2bf1..719bd2c9f80 100644
--- a/tests/integration_tests/test_cli.py
+++ b/tests/integration_tests/test_cli.py
@@ -137,10 +137,7 @@ def test_that_the_cli_raises_exceptions_when_parameters_are_missing(mode):
 def test_that_the_cli_raises_exceptions_when_no_weight_provided_for_es_mda():
     with pytest.raises(
         ErtCliError,
-        match=(
-            "Operation halted: ES-MDA requires weights to proceed. "
-            "Please provide appropriate weights and try again."
-        ),
+        match="Invalid weights: 0",
     ):
         run_cli(
             ES_MDA_MODE,
diff --git a/tests/unit_tests/run_models/test_multiple_data_assimilation.py b/tests/unit_tests/run_models/test_multiple_data_assimilation.py
index d3a82055c03..c3cb182ee6d 100644
--- a/tests/unit_tests/run_models/test_multiple_data_assimilation.py
+++ b/tests/unit_tests/run_models/test_multiple_data_assimilation.py
@@ -4,29 +4,21 @@
 from ert.run_models import MultipleDataAssimilation as mda
 
 
-def test_normalized_weights():
-    weights = mda.normalizeWeights([1])
-    assert weights == [1.0]
-
-    weights = mda.normalizeWeights([1, 1])
-    assert weights == [2.0, 2.0]
-
-    weights = np.array(mda.normalizeWeights([8, 4, 2, 1]))
+@pytest.mark.parametrize(
+    "weights, expected",
+    [
+        ("2, 2, 2, 2", [4] * 4),
+        ("1, 2, 4, ", [1.75, 3.5, 7.0]),
+        ("1, 0, 1, ", [2, 2]),
+        ("1.414213562373095, 1.414213562373095", [2, 2]),
+    ],
+)
+def test_weights(weights, expected):
+    weights = mda.parse_weights(weights)
+    assert weights == expected
     assert np.reciprocal(weights).sum() == 1.0
 
 
-def test_weights():
-    weights = mda.parseWeights("2, 2, 2, 2")
-    assert weights == [2, 2, 2, 2]
-
-    weights = mda.parseWeights("1, 2, 3, ")
-    assert weights == [1, 2, 3]
-
-    weights = mda.parseWeights("1, 0, 1")
-    assert weights == [1, 1]
-
-    weights = mda.parseWeights("1.414213562373095, 1.414213562373095")
-    assert weights == [1.414213562373095, 1.414213562373095]
-
+def test_invalid_weights():
     with pytest.raises(ValueError):
-        mda.parseWeights("2, error, 2, 2")
+        mda.parse_weights("2, error, 2, 2")
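
Editor's note (not part of the patch): a short worked example of the contract the new parse_weights enforces, matching its docstring and the parametrized test above: the reciprocals of the returned weights sum to 1.0 (Evensen 2018, Eq. 38). The variable names below are illustrative only.

    # Relative weights "1, 2, 4" are rescaled by scale = 1/1 + 1/2 + 1/4 = 1.75,
    # giving [1.75, 3.5, 7.0], whose reciprocals sum to 1.0.
    relative = [1.0, 2.0, 4.0]
    scale = sum(1.0 / w for w in relative)        # 1.75
    normalized = [w * scale for w in relative]    # [1.75, 3.5, 7.0]

    assert abs(sum(1.0 / w for w in normalized) - 1.0) < 1e-12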