From b087917b498ff2884d9d6eee71b0140d6db553cf Mon Sep 17 00:00:00 2001 From: Frode Aarstad Date: Thu, 9 Jan 2025 08:51:45 +0100 Subject: [PATCH] Refactor update workflow signature --- src/ert/analysis/_es_update.py | 25 ++++++++++--------- src/ert/run_models/base_run_model.py | 14 ++++++----- .../performance_tests/test_memory_usage.py | 5 +++- .../test_obs_and_responses_performance.py | 7 +++++- .../scenarios/test_summary_response.py | 11 +++++++- .../storage/test_storage_migration.py | 5 +++- 6 files changed, 45 insertions(+), 22 deletions(-) diff --git a/src/ert/analysis/_es_update.py b/src/ert/analysis/_es_update.py index 54d0b5e6213..44a10ea63e8 100644 --- a/src/ert/analysis/_es_update.py +++ b/src/ert/analysis/_es_update.py @@ -22,7 +22,7 @@ from ert.config import GenKwConfig from ..config.analysis_config import ObservationGroups, UpdateSettings -from ..config.analysis_module import ESSettings, IESSettings +from ..config.analysis_module import BaseSettings, ESSettings, IESSettings from . import misfit_preprocessor from .event import ( AnalysisCompleteEvent, @@ -749,14 +749,14 @@ def analysis_IES( def _create_smoother_snapshot( prior_name: str, posterior_name: str, - analysis_config: UpdateSettings, + update_settings: UpdateSettings, global_scaling: float, ) -> SmootherSnapshot: return SmootherSnapshot( source_ensemble_name=prior_name, target_ensemble_name=posterior_name, - alpha=analysis_config.alpha, - std_cutoff=analysis_config.std_cutoff, + alpha=update_settings.alpha, + std_cutoff=update_settings.std_cutoff, global_scaling=global_scaling, update_step_snapshots=[], ) @@ -767,8 +767,8 @@ def smoother_update( posterior_storage: Ensemble, observations: Iterable[str], parameters: Iterable[str], - analysis_config: UpdateSettings | None = None, - es_settings: ESSettings | None = None, + update_settings: UpdateSettings, + es_settings: BaseSettings, rng: np.random.Generator | None = None, progress_callback: Callable[[AnalysisEvent], None] | None = None, global_scaling: float = 1.0, @@ -777,14 +777,15 @@ def smoother_update( progress_callback = noop_progress_callback if rng is None: rng = np.random.default_rng() - analysis_config = UpdateSettings() if analysis_config is None else analysis_config - es_settings = ESSettings() if es_settings is None else es_settings + + assert isinstance(es_settings, ESSettings) + ens_mask = prior_storage.get_realization_mask_with_responses() smoother_snapshot = _create_smoother_snapshot( prior_storage.name, posterior_storage.name, - analysis_config, + update_settings, global_scaling, ) @@ -794,15 +795,15 @@ def smoother_update( observations, rng, es_settings, - analysis_config.alpha, - analysis_config.std_cutoff, + update_settings.alpha, + update_settings.std_cutoff, global_scaling, smoother_snapshot, ens_mask, prior_storage, posterior_storage, progress_callback, - analysis_config.auto_scale_observations, + update_settings.auto_scale_observations, ) except Exception as e: progress_callback( diff --git a/src/ert/run_models/base_run_model.py b/src/ert/run_models/base_run_model.py index c5564e560f3..09ca0b9557a 100644 --- a/src/ert/run_models/base_run_model.py +++ b/src/ert/run_models/base_run_model.py @@ -31,7 +31,8 @@ AnalysisDataEvent, AnalysisErrorEvent, ) -from ert.config import ErtConfig, ESSettings, HookRuntime, QueueSystem +from ert.config import ErtConfig, HookRuntime, QueueSystem +from ert.config.analysis_module import BaseSettings from ert.enkf_main import _seed_sequence, create_run_path from ert.ensemble_evaluator import Ensemble as EEEnsemble from ert.ensemble_evaluator import ( @@ -732,7 +733,7 @@ def _evaluate_and_postprocess( class UpdateRunModel(BaseRunModel): def __init__( self, - es_settings: ESSettings, + analysis_settings: BaseSettings, update_settings: UpdateSettings, config: ErtConfig, storage: Storage, @@ -744,8 +745,9 @@ def __init__( random_seed: int | None, minimum_required_realizations: int, ): - self.es_settings = es_settings - self.update_settings = update_settings + self._analysis_settings: BaseSettings = analysis_settings + self._update_settings: UpdateSettings = update_settings + super().__init__( config, storage, @@ -786,8 +788,8 @@ def update( smoother_update( prior, posterior, - analysis_config=self.update_settings, - es_settings=self.es_settings, + update_settings=self._update_settings, + es_settings=self._analysis_settings, parameters=prior.experiment.update_parameters, observations=prior.experiment.observation_keys, global_scaling=weight, diff --git a/tests/ert/performance_tests/test_memory_usage.py b/tests/ert/performance_tests/test_memory_usage.py index d1e0bb24ad0..abbe8d98ccf 100644 --- a/tests/ert/performance_tests/test_memory_usage.py +++ b/tests/ert/performance_tests/test_memory_usage.py @@ -13,7 +13,8 @@ import xtgeo from ert.analysis import smoother_update -from ert.config import ErtConfig +from ert.config import ErtConfig, ESSettings +from ert.config.analysis_config import UpdateSettings from ert.enkf_main import sample_prior from ert.mode_definitions import ENSEMBLE_SMOOTHER_MODE from ert.storage import open_storage @@ -65,6 +66,8 @@ def test_memory_smoothing(poly_template): posterior_ens, list(experiment.observation_keys), list(ert_config.ensemble_config.parameters), + UpdateSettings(), + ESSettings(), ) stats = memray._memray.compute_statistics(str(poly_template / "memray.bin")) diff --git a/tests/ert/performance_tests/test_obs_and_responses_performance.py b/tests/ert/performance_tests/test_obs_and_responses_performance.py index 8c70728b415..0e28f4ba42e 100644 --- a/tests/ert/performance_tests/test_obs_and_responses_performance.py +++ b/tests/ert/performance_tests/test_obs_and_responses_performance.py @@ -8,7 +8,8 @@ import pytest from ert.analysis import smoother_update -from ert.config import GenDataConfig, GenKwConfig, SummaryConfig +from ert.config import ESSettings, GenDataConfig, GenKwConfig, SummaryConfig +from ert.config.analysis_config import UpdateSettings from ert.config.gen_kw_config import TransformFunctionDefinition from ert.enkf_main import sample_prior from ert.storage import open_storage @@ -506,6 +507,8 @@ def test_memory_performance_of_doing_es_update(setup_es_benchmark, tmp_path): posterior, prior.experiment.observation_keys, [gen_kw_name], + UpdateSettings(), + ESSettings(), ) stats = memray._memray.compute_statistics(str(tmp_path / "memray.bin")) @@ -525,6 +528,8 @@ def run(): posterior, prior.experiment.observation_keys, [gen_kw_name], + UpdateSettings(), + ESSettings(), ) benchmark(run) diff --git a/tests/ert/unit_tests/scenarios/test_summary_response.py b/tests/ert/unit_tests/scenarios/test_summary_response.py index 7e4102beff9..0cf556272bb 100644 --- a/tests/ert/unit_tests/scenarios/test_summary_response.py +++ b/tests/ert/unit_tests/scenarios/test_summary_response.py @@ -11,7 +11,8 @@ from ert import LibresFacade from ert.analysis import ErtAnalysisError, smoother_update -from ert.config import ErtConfig +from ert.config import ErtConfig, ESSettings +from ert.config.analysis_config import UpdateSettings from ert.data import MeasuredData from ert.enkf_main import sample_prior @@ -102,6 +103,8 @@ def test_that_reading_matching_time_is_ok(ert_config, storage, prior_ensemble): target_ensemble, prior_ensemble.experiment.observation_keys, ert_config.ensemble_config.parameters, + UpdateSettings(), + ESSettings(), ) @@ -129,6 +132,8 @@ def test_that_mismatched_responses_give_error(ert_config, storage, prior_ensembl target_ensemble, prior_ensemble.experiment.observation_keys, ert_config.ensemble_config.parameters, + UpdateSettings(), + ESSettings(), ) @@ -160,6 +165,8 @@ def test_that_different_length_is_ok_as_long_as_observation_time_exists( target_ensemble, prior_ensemble.experiment.observation_keys, ert_config.ensemble_config.parameters, + UpdateSettings(), + ESSettings(), ) @@ -206,6 +213,8 @@ def test_that_duplicate_summary_time_steps_does_not_fail( target_ensemble, prior_ensemble.experiment.observation_keys, ert_config.ensemble_config.parameters, + UpdateSettings(), + ESSettings(), ) diff --git a/tests/ert/unit_tests/storage/test_storage_migration.py b/tests/ert/unit_tests/storage/test_storage_migration.py index 952e6b16055..1721b954a55 100644 --- a/tests/ert/unit_tests/storage/test_storage_migration.py +++ b/tests/ert/unit_tests/storage/test_storage_migration.py @@ -10,7 +10,8 @@ from packaging import version from ert.analysis import ErtAnalysisError, smoother_update -from ert.config import ErtConfig +from ert.config import ErtConfig, ESSettings +from ert.config.analysis_config import UpdateSettings from ert.storage import open_storage from ert.storage.local_storage import ( _LOCAL_STORAGE_VERSION, @@ -467,6 +468,8 @@ def test_that_manual_update_from_migrated_storage_works( posterior_ens, list(experiment.observation_keys), list(ert_config.ensemble_config.parameters), + UpdateSettings(), + ESSettings(), )