Skip to content

Commit

Permalink
Refactor update workflow signature
Browse files Browse the repository at this point in the history
  • Loading branch information
frode-aarstad committed Jan 13, 2025
1 parent 5a0ff5f commit b087917
Show file tree
Hide file tree
Showing 6 changed files with 45 additions and 22 deletions.
25 changes: 13 additions & 12 deletions src/ert/analysis/_es_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from ert.config import GenKwConfig

from ..config.analysis_config import ObservationGroups, UpdateSettings
from ..config.analysis_module import ESSettings, IESSettings
from ..config.analysis_module import BaseSettings, ESSettings, IESSettings
from . import misfit_preprocessor
from .event import (
AnalysisCompleteEvent,
Expand Down Expand Up @@ -749,14 +749,14 @@ def analysis_IES(
def _create_smoother_snapshot(
prior_name: str,
posterior_name: str,
analysis_config: UpdateSettings,
update_settings: UpdateSettings,
global_scaling: float,
) -> SmootherSnapshot:
return SmootherSnapshot(
source_ensemble_name=prior_name,
target_ensemble_name=posterior_name,
alpha=analysis_config.alpha,
std_cutoff=analysis_config.std_cutoff,
alpha=update_settings.alpha,
std_cutoff=update_settings.std_cutoff,
global_scaling=global_scaling,
update_step_snapshots=[],
)
Expand All @@ -767,8 +767,8 @@ def smoother_update(
posterior_storage: Ensemble,
observations: Iterable[str],
parameters: Iterable[str],
analysis_config: UpdateSettings | None = None,
es_settings: ESSettings | None = None,
update_settings: UpdateSettings,
es_settings: BaseSettings,
rng: np.random.Generator | None = None,
progress_callback: Callable[[AnalysisEvent], None] | None = None,
global_scaling: float = 1.0,
Expand All @@ -777,14 +777,15 @@ def smoother_update(
progress_callback = noop_progress_callback
if rng is None:
rng = np.random.default_rng()
analysis_config = UpdateSettings() if analysis_config is None else analysis_config
es_settings = ESSettings() if es_settings is None else es_settings

assert isinstance(es_settings, ESSettings)

ens_mask = prior_storage.get_realization_mask_with_responses()

smoother_snapshot = _create_smoother_snapshot(
prior_storage.name,
posterior_storage.name,
analysis_config,
update_settings,
global_scaling,
)

Expand All @@ -794,15 +795,15 @@ def smoother_update(
observations,
rng,
es_settings,
analysis_config.alpha,
analysis_config.std_cutoff,
update_settings.alpha,
update_settings.std_cutoff,
global_scaling,
smoother_snapshot,
ens_mask,
prior_storage,
posterior_storage,
progress_callback,
analysis_config.auto_scale_observations,
update_settings.auto_scale_observations,
)
except Exception as e:
progress_callback(
Expand Down
14 changes: 8 additions & 6 deletions src/ert/run_models/base_run_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@
AnalysisDataEvent,
AnalysisErrorEvent,
)
from ert.config import ErtConfig, ESSettings, HookRuntime, QueueSystem
from ert.config import ErtConfig, HookRuntime, QueueSystem
from ert.config.analysis_module import BaseSettings
from ert.enkf_main import _seed_sequence, create_run_path
from ert.ensemble_evaluator import Ensemble as EEEnsemble
from ert.ensemble_evaluator import (
Expand Down Expand Up @@ -732,7 +733,7 @@ def _evaluate_and_postprocess(
class UpdateRunModel(BaseRunModel):
def __init__(
self,
es_settings: ESSettings,
analysis_settings: BaseSettings,
update_settings: UpdateSettings,
config: ErtConfig,
storage: Storage,
Expand All @@ -744,8 +745,9 @@ def __init__(
random_seed: int | None,
minimum_required_realizations: int,
):
self.es_settings = es_settings
self.update_settings = update_settings
self._analysis_settings: BaseSettings = analysis_settings
self._update_settings: UpdateSettings = update_settings

super().__init__(
config,
storage,
Expand Down Expand Up @@ -786,8 +788,8 @@ def update(
smoother_update(
prior,
posterior,
analysis_config=self.update_settings,
es_settings=self.es_settings,
update_settings=self._update_settings,
es_settings=self._analysis_settings,
parameters=prior.experiment.update_parameters,
observations=prior.experiment.observation_keys,
global_scaling=weight,
Expand Down
5 changes: 4 additions & 1 deletion tests/ert/performance_tests/test_memory_usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
import xtgeo

from ert.analysis import smoother_update
from ert.config import ErtConfig
from ert.config import ErtConfig, ESSettings
from ert.config.analysis_config import UpdateSettings
from ert.enkf_main import sample_prior
from ert.mode_definitions import ENSEMBLE_SMOOTHER_MODE
from ert.storage import open_storage
Expand Down Expand Up @@ -65,6 +66,8 @@ def test_memory_smoothing(poly_template):
posterior_ens,
list(experiment.observation_keys),
list(ert_config.ensemble_config.parameters),
UpdateSettings(),
ESSettings(),
)

stats = memray._memray.compute_statistics(str(poly_template / "memray.bin"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
import pytest

from ert.analysis import smoother_update
from ert.config import GenDataConfig, GenKwConfig, SummaryConfig
from ert.config import ESSettings, GenDataConfig, GenKwConfig, SummaryConfig
from ert.config.analysis_config import UpdateSettings
from ert.config.gen_kw_config import TransformFunctionDefinition
from ert.enkf_main import sample_prior
from ert.storage import open_storage
Expand Down Expand Up @@ -506,6 +507,8 @@ def test_memory_performance_of_doing_es_update(setup_es_benchmark, tmp_path):
posterior,
prior.experiment.observation_keys,
[gen_kw_name],
UpdateSettings(),
ESSettings(),
)

stats = memray._memray.compute_statistics(str(tmp_path / "memray.bin"))
Expand All @@ -525,6 +528,8 @@ def run():
posterior,
prior.experiment.observation_keys,
[gen_kw_name],
UpdateSettings(),
ESSettings(),
)

benchmark(run)
11 changes: 10 additions & 1 deletion tests/ert/unit_tests/scenarios/test_summary_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@

from ert import LibresFacade
from ert.analysis import ErtAnalysisError, smoother_update
from ert.config import ErtConfig
from ert.config import ErtConfig, ESSettings
from ert.config.analysis_config import UpdateSettings
from ert.data import MeasuredData
from ert.enkf_main import sample_prior

Expand Down Expand Up @@ -102,6 +103,8 @@ def test_that_reading_matching_time_is_ok(ert_config, storage, prior_ensemble):
target_ensemble,
prior_ensemble.experiment.observation_keys,
ert_config.ensemble_config.parameters,
UpdateSettings(),
ESSettings(),
)


Expand Down Expand Up @@ -129,6 +132,8 @@ def test_that_mismatched_responses_give_error(ert_config, storage, prior_ensembl
target_ensemble,
prior_ensemble.experiment.observation_keys,
ert_config.ensemble_config.parameters,
UpdateSettings(),
ESSettings(),
)


Expand Down Expand Up @@ -160,6 +165,8 @@ def test_that_different_length_is_ok_as_long_as_observation_time_exists(
target_ensemble,
prior_ensemble.experiment.observation_keys,
ert_config.ensemble_config.parameters,
UpdateSettings(),
ESSettings(),
)


Expand Down Expand Up @@ -206,6 +213,8 @@ def test_that_duplicate_summary_time_steps_does_not_fail(
target_ensemble,
prior_ensemble.experiment.observation_keys,
ert_config.ensemble_config.parameters,
UpdateSettings(),
ESSettings(),
)


Expand Down
5 changes: 4 additions & 1 deletion tests/ert/unit_tests/storage/test_storage_migration.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
from packaging import version

from ert.analysis import ErtAnalysisError, smoother_update
from ert.config import ErtConfig
from ert.config import ErtConfig, ESSettings
from ert.config.analysis_config import UpdateSettings
from ert.storage import open_storage
from ert.storage.local_storage import (
_LOCAL_STORAGE_VERSION,
Expand Down Expand Up @@ -467,6 +468,8 @@ def test_that_manual_update_from_migrated_storage_works(
posterior_ens,
list(experiment.observation_keys),
list(ert_config.ensemble_config.parameters),
UpdateSettings(),
ESSettings(),
)


Expand Down

0 comments on commit b087917

Please sign in to comment.