From 14848bea3119c44cde1d09d3f728ed6980d35fcc Mon Sep 17 00:00:00 2001
From: "Yngve S. Kristiansen"
Date: Wed, 12 Jun 2024 11:40:23 +0200
Subject: [PATCH] Add logic & tests for realization states

---
 src/ert/analysis/_es_update.py                |   5 +-
 src/ert/storage/local_ensemble.py             | 319 ++++++++--
 src/ert/storage/local_storage.py              |   4 +
 src/ert/storage/realization_state.py          | 316 ++++++++++
 .../storage/test_parameter_sample_types.py    |   2 +-
 tests/unit_tests/analysis/test_es_update.py   |   1 +
 .../scenarios/test_summary_response.py        |   2 +-
 .../migration/test_block_fs_snake_oil.py      |   5 +-
 .../unit_tests/storage/test_local_storage.py  |   1 +
 .../storage/test_realization_state.py         | 591 ++++++++++++++++++++
 tests/unit_tests/test_load_forward_model.py   |   1 +
 11 files changed, 1206 insertions(+), 41 deletions(-)
 create mode 100644 src/ert/storage/realization_state.py
 create mode 100644 tests/unit_tests/storage/test_realization_state.py

diff --git a/src/ert/analysis/_es_update.py b/src/ert/analysis/_es_update.py
index 87b7952e6e9..726c57f0b96 100644
--- a/src/ert/analysis/_es_update.py
+++ b/src/ert/analysis/_es_update.py
@@ -125,8 +125,6 @@ def _save_param_ensemble_array_to_disk(
             ensemble, param_group, realization, param_ensemble_array[:, i]
         )
 
-    ensemble.unify_parameters()
-
 
 def _load_param_ensemble_array(
     ensemble: Ensemble,
@@ -560,6 +558,8 @@ def correlation_callback(
         target_ensemble,
     )
 
+    target_ensemble.unify_parameters()
+
 
 def analysis_IES(
     parameters: Iterable[str],
@@ -664,6 +664,7 @@ def analysis_IES(
             target_ensemble, param_ensemble_array, param_group, iens_active_index
         )
 
+    target_ensemble.unify_parameters()
     _copy_unupdated_parameters(
         list(source_ensemble.experiment.parameter_configuration.keys()),
         parameters,
diff --git a/src/ert/storage/local_ensemble.py b/src/ert/storage/local_ensemble.py
index 42f2b4afc52..60d6ef776e7 100644
--- a/src/ert/storage/local_ensemble.py
+++ b/src/ert/storage/local_ensemble.py
@@ -15,6 +15,7 @@
     List,
     Literal,
     Optional,
+    Set,
     Tuple,
     Union,
 )
@@ -31,8 +32,9 @@
 from ert.config.observations import ObservationsIndices
 from ert.storage.mode import BaseMode, Mode, require_write
 
-from ..config import GenDataConfig, ResponseTypes
+from ..config import GenDataConfig, ResponseTypes, SummaryConfig
 from .ensure_correct_xr_coordinate_order import ensure_correct_coordinate_order
+from .realization_state import _MultiRealizationStateDict
 from .realization_storage_state import RealizationStorageState
 
 if TYPE_CHECKING:
@@ -146,6 +148,15 @@ def create_realization_dir(realization: int) -> Path:
             return self._path / f"realization-{realization}"
 
         self._realization_dir = create_realization_dir
+        self._realization_states = _MultiRealizationStateDict()
+
+    def on_experiment_initialized(self) -> None:
+        """
+        Executes logic that depends on the ensemble's experiment existing.
+        For example, if some logic needs to traverse the response/parameter configs,
+        the experiment of this ensemble must be initialized before running that logic.
+ """ + self._refresh_realization_states() @classmethod def create( @@ -343,14 +354,9 @@ def _parameters_exist_for_realization(self, realization: int) -> bool: """ if not self.experiment.parameter_configuration: return True - path = self._realization_dir(realization) + return all( - ( - self.has_combined_parameter_dataset(parameter) - and realization - in self._load_combined_parameter_dataset(parameter)["realizations"] - ) - or (path / f"{parameter}.nc").exists() + self._realization_states.has_parameter_group(realization, parameter) for parameter in self.experiment.parameter_configuration ) @@ -398,28 +404,13 @@ def _responses_exist_for_realization( otherwise, `False`. """ - if not self.experiment.response_configuration: - return True - - real_dir = self._realization_dir(realization) - if key: - if self.has_combined_response_dataset(key): - return ( - realization - in self._load_combined_response_dataset(key)["realization"] - ) - else: - return (real_dir / f"{key}.nc").exists() - - return all( - (real_dir / f"{response}.nc").exists() - or ( - self.has_combined_response_dataset(response) - and realization - in self._load_combined_response_dataset(response)["realization"].values + if not key: + return all( + self._realization_states.has_response(realization, response) + for response in self.experiment.response_configuration ) - for response in self.experiment.response_configuration - ) + + return self._realization_states.has_response(realization, key) def is_initalized(self) -> List[int]: """ @@ -437,12 +428,7 @@ def is_initalized(self) -> List[int]: i for i in range(self.ensemble_size) if all( - (self._realization_dir(i) / f"{parameter.name}.nc").exists() - for parameter in self.experiment.parameter_configuration.values() - if not parameter.forward_init - ) - or all( - (self._path / f"{parameter.name}.nc").exists() + self._realization_states.has_parameter_group(i, parameter.name) for parameter in self.experiment.parameter_configuration.values() if not parameter.forward_init ) @@ -1294,6 +1280,7 @@ def unify_responses(self, key: Optional[str] = None) -> None: if key is None: for key in self.experiment.response_configuration: self.unify_responses(key) + return None gen_data_keys = { k @@ -1365,6 +1352,8 @@ def unify_responses(self, key: Optional[str] = None) -> None: "realization", ) + self._refresh_realization_states_for_responses(response_key=key) + def unify_parameters(self, key: Optional[str] = None) -> None: self._unify_datasets( ( @@ -1374,6 +1363,7 @@ def unify_parameters(self, key: Optional[str] = None) -> None: ), "realizations", ) + self._refresh_realization_states_for_parameters() def get_parameter_state( self, realization: int @@ -1406,3 +1396,262 @@ def get_response_state( ) for e in self.experiment.response_configuration ) + + def _refresh_realization_states_for_parameters(self) -> None: + old_states = self._realization_states + states = _MultiRealizationStateDict() + + def _refresh_grouped_combined(): + birth_time = os.path.getctime(self._path / f"{parameter_group_key}.nc") + + combined_param_ds = self._load_combined_parameter_dataset( + parameter_group_key + ) + for realization_index in range(self.ensemble_size): + old_state = old_states.get_single_realization_state(realization_index) + + if ( + old_state.has_parameter_key_or_group(parameter_group_key) + and old_state.get_parameter(parameter_group_key).timestamp + == birth_time + ): + continue + + real_state = states.get_single_realization_state(realization_index) + + if realization_index not in 
combined_param_ds["realizations"]: + if "names" in combined_param_ds: + for param_key in combined_param_ds["names"].values: + real_state.set_parameter_group( + key=param_key, + value=False, + parameter_group=parameter_group_key, + source=self._path / f"{parameter_group_key}.nc", + ) + real_state.set_parameter_group( + key=parameter_group_key, + value=False, + parameter_group=parameter_group_key, + source=self._path / f"{parameter_group_key}.nc", + ) + else: + params_for_realization = combined_param_ds.sel( + realizations=realization_index, drop=True + ) + + if "names" in params_for_realization: + df = ( + params_for_realization[["names", "values"]] + .to_dataframe() + .reset_index() + ) + for _, row in df.iterrows(): + real_state.set_parameter_group( + key=row["names"], + value=not pd.isna(row["values"]), + parameter_group=parameter_group_key, + source=self._path / f"{parameter_group_key}.nc", + ) + else: + # Surfaces and fields don't have a "name", + # not the concept of a parameter group as opposed + # to GEN_KW + real_state.set_parameter_group( + key=parameter_group_key, + value=realization_index + in combined_param_ds["realizations"], + parameter_group=parameter_group_key, + source=self._path / f"{parameter_group_key}.nc", + ) + + def _refresh_grouped_not_combined(): + for realization_index in range(self.ensemble_size): + real_state = states.get_single_realization_state(realization_index) + + ds_path = ( + self._realization_dir(realization_index) + / f"{parameter_group_key}.nc" + ) + + if not os.path.exists(ds_path): + # Here we can't really know what the specific keys of + # the parameter is (if any), so we need to just set that + # the entire group was not found + real_state.set_parameter_group( + parameter_group_key, + False, + parameter_group_key, + ) + continue + + old_state = old_states.get_single_realization_state(realization_index) + + if old_state.has_parameter_key_or_group( + parameter_group_key + ) and old_state.get_parameter( + parameter_group_key + ).timestamp == os.path.getctime(ds_path): + continue + + ds = xr.open_dataset(ds_path) + for _, row in ( + ds[["names", "values"]].to_dataframe().reset_index().iterrows() + ): + real_state.set_parameter_group( + key=row["names"], + value=not pd.isna(row["values"]), + parameter_group=parameter_group_key, + source=ds_path, + ) + + for parameter_group_key in self.experiment.parameter_configuration: + if self.has_combined_parameter_dataset(parameter_group_key): + _refresh_grouped_combined() + else: + _refresh_grouped_not_combined() + + self._realization_states = ( + old_states.copy().assign_states(states).make_keys_consistent() + ) + + def _refresh_realization_states_for_responses( + self, response_key: Optional[str] = None + ) -> None: + old_states = self._realization_states + states = _MultiRealizationStateDict() + + # If it has gen data, check all names w/ non nan values + has_combined_summary = self.has_combined_response_dataset("summary") + has_combined_gendata = self.has_combined_response_dataset("gen_data") + + response_configs = [ + c + for c in self.experiment.response_configuration.values() + if response_key is None or c.name == response_key + ] + + def _refresh_for_gendata_not_combined(key: str): + for realization_index in range(self.ensemble_size): + ds_path = self._realization_dir(realization_index) / f"{key}.nc" + realization_state = states.get_single_realization_state( + realization_index + ) + + realization_state.set_response( + key=key, + value=ds_path.exists(), + response_type="gen_data", + source=ds_path, + ) + + 
# ._load_combined_response_dataset(key) takes pretty long and needs to be done only + # once per combined dataset, per call to this + _cached_open_combined_datasets: Dict[str, xr.Dataset] = {} + + def _refresh_for_responses_combined(response_type, response_keys: Set[str]): + _cached_open_combined_datasets[response_type] = combined_ds = ( + _cached_open_combined_datasets[response_type] + if response_type in _cached_open_combined_datasets + else self._load_combined_response_dataset(response_type) + ) + + ds_time_of_birth = os.path.getctime(self._path / f"{response_type}.nc") + + for realization_index in range(self.ensemble_size): + real_state = states.get_single_realization_state(realization_index) + + if ( + real_state.has_response_key_or_group(response_type) + and real_state.get_response(response_type).timestamp + == ds_time_of_birth + ): + continue + + if realization_index not in combined_ds["realization"]: + for expected_key in [response_key, *response_keys]: + real_state.set_response( + key=expected_key, + value=False, + response_type=response_type, + source=self._path / f"{response_type}.nc", + ) + + continue + + summary_4real = combined_ds.sel(realization=realization_index).dropna( + "name", how="all" + )["name"] + found_keys = summary_4real["name"].data + + for expected_key in response_keys: + real_state.set_response( + expected_key, + value=expected_key in found_keys, + response_type=response_type, + source=self._path / f"{response_type}.nc", + ) + + def _refresh_for_summary_not_combined(expected_keys: Set[str]): + for realization_index in range(self.ensemble_size): + ds_path = self._realization_dir(realization_index) / "summary.nc" + realization_state = states.get_single_realization_state( + realization_index + ) + + if ds_path.exists(): + if realization_state.has_response_key_or_group( + "summary" + ) and realization_state.get_response( + "summary" + ).timestamp == os.path.getctime(ds_path): + continue + + ds = xr.open_dataset(ds_path) + for smry_key in expected_keys: + realization_state.set_response( + key=smry_key, + value=smry_key in ds["name"], + response_type="summary", + source=ds_path, + ) + else: + for smry_key in expected_keys: + realization_state.set_response( + key=smry_key, + value=False, + response_type="summary", + source=ds_path, + ) + + keys_per_response_type: Dict[str, Set[str]] = {} + + for response_config in response_configs: + if isinstance(response_config, GenDataConfig): + if "gen_data" not in keys_per_response_type: + keys_per_response_type["gen_data"] = set() + + keys_per_response_type["gen_data"].add(response_config.name) + + elif isinstance(response_config, SummaryConfig): + assert isinstance(response_config, SummaryConfig) + if "summary" not in keys_per_response_type: + keys_per_response_type["summary"] = set() + + keys_per_response_type["summary"].update(response_config.keys) + + for response_type, expected_keys in keys_per_response_type.items(): + if self.has_combined_response_dataset(response_type): + _refresh_for_responses_combined(response_type, expected_keys) + elif response_type == "gen_data": + # Not combined and one file per response key + for k in expected_keys: + _refresh_for_gendata_not_combined(k) + else: + # Not combined and one file with multiple response keys + _refresh_for_summary_not_combined(expected_keys) + + self._realization_states = old_states.copy().assign_states(states) + + def _refresh_realization_states(self, response_key: Optional[str] = None) -> None: + self._refresh_realization_states_for_responses(response_key) + 
self._refresh_realization_states_for_parameters() diff --git a/src/ert/storage/local_storage.py b/src/ert/storage/local_storage.py index 9647910f5e4..f2029a8c424 100644 --- a/src/ert/storage/local_storage.py +++ b/src/ert/storage/local_storage.py @@ -130,6 +130,9 @@ def refresh(self) -> None: self._ensembles = self._load_ensembles() self._experiments = self._load_experiments() + for ens in self._ensembles.values(): + ens.on_experiment_initialized() + def get_experiment(self, uuid: UUID) -> LocalExperiment: """ Retrieves an experiment by UUID. @@ -404,6 +407,7 @@ def create_ensemble( name=str(name), prior_ensemble_id=prior_ensemble_id, ) + if prior_ensemble: for realization, state in enumerate(prior_ensemble.get_ensemble_state()): if state in [ diff --git a/src/ert/storage/realization_state.py b/src/ert/storage/realization_state.py new file mode 100644 index 00000000000..0fe5242ed4c --- /dev/null +++ b/src/ert/storage/realization_state.py @@ -0,0 +1,316 @@ +import dataclasses +import os +import time +from typing import Dict, List, Optional, Set, Tuple + +import pandas +from typing_extensions import Self + + +@dataclasses.dataclass +class _SingleRealizationStateDictEntry: + value: bool = dataclasses.field(default=False) + timestamp: float = dataclasses.field(default=-1) + + def update(self, value: bool, timestamp: float = -1): + if timestamp is None: + timestamp = time.time() + + self.value = value + self.timestamp = timestamp + + def copy(self) -> "_SingleRealizationStateDictEntry": + return _SingleRealizationStateDictEntry( + value=self.value, timestamp=self.timestamp + ) + + def assign_state(self, src_state: "_SingleRealizationStateDictEntry") -> Self: + if src_state.timestamp == -1 and self.timestamp != -1: + return self + + if src_state.timestamp == -1 and self.timestamp == -1: + # TODO branch may not be needed + return self + + if src_state.timestamp > self.timestamp: + self.value = src_state.value + self.timestamp = src_state.timestamp + + return self + + +class _SingleRealizationStateDict: + def __init__(self) -> None: + self._items_by_kind: Dict[str, Dict[str, _SingleRealizationStateDictEntry]] = {} + + def _set_item( + self, key: str, value: bool, kind: str, source: Optional[os.PathLike] = None + ): + if key == kind and kind in self._items_by_kind: + for k in set(self._items_by_kind[kind]) - {kind}: + self._set_item(k, value, kind, source) + + return + + if kind not in self._items_by_kind: + self._items_by_kind[kind] = {} + + items_for_kind = self._items_by_kind[kind] + + timestamp = os.path.getctime(source) if (source is not None and value) else -1 + + if key not in items_for_kind: + items_for_kind[key] = _SingleRealizationStateDictEntry( + value=value, timestamp=timestamp + ) + + items_for_kind[key].update(value, timestamp) + + def set_response( + self, + key: str, + value: bool, + response_type: str, + source: Optional[os.PathLike] = None, + ): + self._set_item(key=key, value=value, kind=response_type, source=source) + + def set_parameter_group( + self, + key: str, + value: bool, + parameter_group: str, + source: Optional[os.PathLike] = None, + ): + self._set_item(key=key, value=value, kind=parameter_group, source=source) + + def _lookup_single_kind_dict_for_key( + self, key: str + ) -> Dict[str, _SingleRealizationStateDictEntry]: + matches = [ + (kind, kind_dict) + for kind, kind_dict in self._items_by_kind.items() + if key in kind_dict + ] + + if len(matches) == 0: + return {} + + assert len(matches) == 1, ( + f"Expected to find only one matching" + f" kind for key {key}, but 
found " + f"{', '.join([k for k,_ in matches])}" + ) + return matches[0][1] + + def has_response_key_or_group(self, key: str) -> bool: + if key in self._items_by_kind: + # It is a response type + return any(x.value for x in self._items_by_kind[key].values()) + + matching_kind_dict = self._lookup_single_kind_dict_for_key(key) + + return key in matching_kind_dict and matching_kind_dict[key].value + + def has_parameter_key_or_group(self, key: str): + if key in self._items_by_kind: + # It is a parameter group + # they are always all written at the same time, + # question2reviewer: If they are all nan, it means that + # the parameter was somehow sampled and ended up being all nan + # does that mean that it still HAS the parameter as in it is + # "something", or does that mean we should return False here? + # current assumption: We need at least one non NaN + return any(x.value for x in self._items_by_kind[key].values()) + + matching_kind_dict = self._lookup_single_kind_dict_for_key(key) + + return key in matching_kind_dict and matching_kind_dict[key] + + def get_response(self, key_or_group: str) -> _SingleRealizationStateDictEntry: + if key_or_group in self._items_by_kind: + kind_dicts = self._items_by_kind[key_or_group].values() + return _SingleRealizationStateDictEntry( + value=any(x.value for x in kind_dicts), + timestamp=max(x.timestamp for x in kind_dicts), + ) + + kind_dict = self._lookup_single_kind_dict_for_key(key_or_group) + return kind_dict.get(key_or_group) + + def get_parameter(self, key_or_group: str) -> _SingleRealizationStateDictEntry: + if key_or_group in self._items_by_kind: + kind_dicts = self._items_by_kind[key_or_group].values() + return _SingleRealizationStateDictEntry( + value=any(x.value for x in kind_dicts), + timestamp=max(x.timestamp for x in kind_dicts), + ) + + kind_dict = self._lookup_single_kind_dict_for_key(key_or_group) + + entry = kind_dict.get(key_or_group) + assert entry is not None + return entry + + def copy(self) -> "_SingleRealizationStateDict": + cpy = _SingleRealizationStateDict() + cpy._items_by_kind = { + k: {kind: entry.copy() for kind, entry in kind_to_entries.items()} + for k, kind_to_entries in self._items_by_kind.items() + } + + return cpy + + def make_keys_consistent(self, keys_per_kind: Dict[str, Set[str]]) -> None: + for kind, items in self._items_by_kind.items(): + if set(items) == {kind}: + entry = items[kind] + for key in keys_per_kind[kind]: + items[key] = entry.copy() + + for kind, items in self._items_by_kind.items(): + if kind in set(items) and set(items) != {kind}: + del items[kind] + + def assign_state(self, src_state: "_SingleRealizationStateDict") -> Self: + for src_kind, src_items_by_kind in src_state._items_by_kind.items(): + if set(src_items_by_kind) == {src_kind}: + # Set all existing keys of this state + if src_kind not in self._items_by_kind: + self._items_by_kind[src_kind] = { + src_kind: src_items_by_kind[src_kind].copy() + } + + if set(self._items_by_kind[src_kind]) == {src_kind}: + self._items_by_kind[src_kind][src_kind] = src_items_by_kind[ + src_kind + ].copy() + else: + for k in self._items_by_kind[src_kind]: + self._items_by_kind[src_kind][k] = src_items_by_kind[ + src_kind + ].copy() + continue + + elif src_kind not in self._items_by_kind: + self._items_by_kind[src_kind] = { + k: v.copy() for k, v in src_items_by_kind.items() + } + continue + + src_keys_for_kind = set(src_items_by_kind) + my_keys_for_kind = set(self._items_by_kind[src_kind]) + all_keys = src_keys_for_kind.union(my_keys_for_kind) + + if 
src_keys_for_kind == {src_kind}: + # src has all keys for kind set to the same thing + state_for_all = src_state._items_by_kind[src_kind][src_kind] + for k in all_keys - {src_kind}: + self._items_by_kind[src_kind][k] = state_for_all.copy() + continue + + for k in all_keys: + if k in src_keys_for_kind: + src_state_entry = src_items_by_kind[k] + if k not in my_keys_for_kind: + self._items_by_kind[src_kind][k] = src_state_entry.copy() + elif k in my_keys_for_kind: + my_state = self._items_by_kind[src_kind][k] + self._items_by_kind[src_kind][k] = my_state.copy().assign_state( + src_state_entry + ) + + return self + + def to_tuples(self) -> List[Tuple[str, str, _SingleRealizationStateDictEntry]]: + tuples = [] + for kind, items_for_kind in self._items_by_kind.items(): + for key, entry in items_for_kind.items(): + tuples.append((kind, key, entry)) + + return tuples + + +class _MultiRealizationStateDict: + def __init__(self) -> None: + self._items: Dict[int, _SingleRealizationStateDict] = {} + + def has_response(self, realization: int, key: str) -> bool: + if realization not in self._items: + return False + + return self._items[realization].has_response_key_or_group(key) + + def has_parameter_group(self, realization: int, key: str) -> bool: + if realization not in self._items: + return False + + return self._items[realization].has_parameter_key_or_group(key) + + def is_empty(self) -> bool: + return self._items == {} + + def get_single_realization_state( + self, realization: int + ) -> _SingleRealizationStateDict: + if realization not in self._items: + self._items[realization] = _SingleRealizationStateDict() + + return self._items[realization] + + def copy(self) -> "_MultiRealizationStateDict": + cpy = _MultiRealizationStateDict() + cpy._items = { + realization_index: state.copy() + for realization_index, state in self._items.items() + } + return cpy + + def assign_states(self, source: "_MultiRealizationStateDict") -> Self: + for realization_index, realization_state in source._items.items(): + if realization_index not in self._items: + self._items[realization_index] = realization_state.copy() + else: + self._items[realization_index].assign_state(realization_state) + + return self + + def make_keys_consistent(self) -> Self: + keys_per_kind: Dict[str, Set[str]] = {} + for state in self._items.values(): + for kind, key, _ in state.to_tuples(): + if kind not in keys_per_kind: + keys_per_kind[kind] = set() + + keys_per_kind[kind].add(key) + + for kind, keys in keys_per_kind.items(): + if set(keys) != {kind} and kind in keys: + keys.remove(kind) + + for state in self._items.values(): + state.make_keys_consistent(keys_per_kind) + + return self + + def to_dataframe(self) -> pandas.DataFrame: + # One column per realization + # One row per kind-key + rows = [] + for real, state in self._items.items(): + for kind, key, entry in state.to_tuples(): + rows.append((real, kind, key, entry)) + + return ( + pandas.DataFrame( + data={ + "realization": [row[0] for row in rows], + "kind": [row[1] for row in rows], + "key": [row[2] for row in rows], + "value": [row[3].value for row in rows], + "timestamp": [row[3].timestamp for row in rows], + } + ) + .set_index(["realization", "kind", "key"]) + .sort_values(["realization", "kind", "key"]) + ) diff --git a/tests/integration_tests/storage/test_parameter_sample_types.py b/tests/integration_tests/storage/test_parameter_sample_types.py index 523babda1de..54ea6918e81 100644 --- a/tests/integration_tests/storage/test_parameter_sample_types.py +++ 
b/tests/integration_tests/storage/test_parameter_sample_types.py @@ -178,7 +178,7 @@ def test_surface_param_update(tmpdir): @pytest.mark.integration_test @pytest.mark.limit_memory("130 MB") -@pytest.mark.flaky(reruns=5) +@pytest.mark.flaky(reruns=0) def test_field_param_memory(tmpdir): with tmpdir.as_cwd(): # Setup is done in a subprocess so that memray does not pick up the allocations diff --git a/tests/unit_tests/analysis/test_es_update.py b/tests/unit_tests/analysis/test_es_update.py index ecc4465af7a..fdc2e30a64c 100644 --- a/tests/unit_tests/analysis/test_es_update.py +++ b/tests/unit_tests/analysis/test_es_update.py @@ -434,6 +434,7 @@ def test_smoother_snapshot_alpha( sies_smoother = None # The initial_mask equals ens_mask on first iteration + prior_storage._refresh_realization_states() initial_mask = prior_storage.get_realization_mask_with_responses() with expectation: diff --git a/tests/unit_tests/scenarios/test_summary_response.py b/tests/unit_tests/scenarios/test_summary_response.py index 8657783dc54..bdcdf85dd51 100644 --- a/tests/unit_tests/scenarios/test_summary_response.py +++ b/tests/unit_tests/scenarios/test_summary_response.py @@ -78,7 +78,7 @@ def create_responses(config_file, prior_ensemble, response_times): facade.load_from_forward_model( prior_ensemble, [True] * facade.get_ensemble_size(), 0 ) - prior_ensemble.unify_responses() + prior_ensemble._refresh_realization_states() def test_that_reading_matching_time_is_ok(ert_config, storage, prior_ensemble): diff --git a/tests/unit_tests/storage/migration/test_block_fs_snake_oil.py b/tests/unit_tests/storage/migration/test_block_fs_snake_oil.py index 9f92a9ab7b7..1dc3c408e50 100644 --- a/tests/unit_tests/storage/migration/test_block_fs_snake_oil.py +++ b/tests/unit_tests/storage/migration/test_block_fs_snake_oil.py @@ -9,7 +9,7 @@ import ert.storage import ert.storage.migration.block_fs as bf -from ert.config import ErtConfig, GenKwConfig +from ert.config import ErtConfig, GenDataConfig, GenKwConfig from ert.config.summary_config import SummaryConfig from ert.storage import open_storage from ert.storage.local_storage import local_storage_set_ert_config @@ -102,6 +102,7 @@ def test_migrate_summary(data, forecast, time_map, tmp_path): ] ) ensemble = experiment.create_ensemble(name="default_0", ensemble_size=5) + ensemble._refresh_realization_states() bf._migrate_summary(ensemble, forecast, time_map) ensemble.unify_responses() @@ -124,7 +125,7 @@ def test_migrate_gen_data(data, forecast, tmp_path): with open_storage(tmp_path / "storage", mode="w") as storage: experiment = storage.create_experiment( responses=[ - SummaryConfig(name=name, input_file="some_file", keys=["some_key"]) + GenDataConfig(name=name, input_file="some_file") for name in ( "SNAKE_OIL_WPR_DIFF", "SNAKE_OIL_OPR_DIFF", diff --git a/tests/unit_tests/storage/test_local_storage.py b/tests/unit_tests/storage/test_local_storage.py index 5f04ab959af..5080073ae5e 100644 --- a/tests/unit_tests/storage/test_local_storage.py +++ b/tests/unit_tests/storage/test_local_storage.py @@ -480,6 +480,7 @@ class Experiment: observations: Dict[str, xr.Dataset] = field(default_factory=dict) +# @reproduce_failure('6.103.2', b'AXicY2BgZCARMOJgg7mMYDFuRgiNz2xGDB5EhBGhj5EBAAeeAB0=') class StatefulStorageTest(RuleBasedStateMachine): """ This test runs several commands against storage and diff --git a/tests/unit_tests/storage/test_realization_state.py b/tests/unit_tests/storage/test_realization_state.py new file mode 100644 index 00000000000..e0c2c8ffc15 --- /dev/null +++ 
b/tests/unit_tests/storage/test_realization_state.py @@ -0,0 +1,591 @@ +import os + +import pandas as pd + +from ert.config import GenDataConfig, GenKwConfig, SummaryConfig +from ert.config.gen_kw_config import TransformFunctionDefinition +from ert.storage import open_storage +from ert.storage.realization_state import ( + _SingleRealizationStateDict, + _SingleRealizationStateDictEntry, +) +from tests.performance_tests.test_memory_usage import make_gen_data, make_summary_data + + +def test_that_realization_states_with_no_params_or_responses_shows_empty(tmp_path): + with open_storage(tmp_path, mode="w") as storage: + experiment = storage.create_experiment() + ensemble = storage.create_ensemble(experiment, name="foo", ensemble_size=1) + + ensemble._refresh_realization_states() + + states = ensemble._realization_states + assert states.is_empty() + assert ensemble._realization_states.to_dataframe().empty + + +def test_that_realization_states_shows_all_params_present(tmp_path): + with open_storage(tmp_path, mode="w") as storage: + experiment = storage.create_experiment( + parameters=[ + GenKwConfig( + name="PARAMETER_GROUP", + forward_init=False, + template_file="", + transform_function_definitions=[ + TransformFunctionDefinition("KEY1", "UNIFORM", [0, 1]), + TransformFunctionDefinition("KEY2", "UNIFORM", [0, 1]), + TransformFunctionDefinition("KEY3", "UNIFORM", [0, 1]), + ], + output_file="kw.txt", + update=True, + ) + ] + ) + ensemble = storage.create_ensemble(experiment, name="foo", ensemble_size=25) + for i in range(1, 25): + ensemble.save_parameters( + "PARAMETER_GROUP", + i, + pd.DataFrame( + data={ + "names": ["KEY1", "KEY2", "KEY3"], + "values": [1, 2, 3], + "transformed_values": [2, 4, 6], + } + ) + .set_index(["names"]) + .to_xarray(), + ) + + ensemble._refresh_realization_states() + state_df_before_combine = ensemble._realization_states.to_dataframe() + + for i in range(1, 25): + real_state = ensemble._realization_states.get_single_realization_state(i) + ds_path = ensemble._realization_dir(i) / "PARAMETER_GROUP.nc" + + tob = os.path.getctime(ds_path) if os.path.exists(ds_path) else -1 + + assert real_state.has_parameter_key_or_group("PARAMETER_GROUP") + assert real_state.get_parameter("PARAMETER_GROUP").timestamp == tob + + ensemble.unify_parameters() + state_df_after_combine = ensemble._realization_states.to_dataframe() + + assert state_df_before_combine["value"].equals(state_df_after_combine["value"]) + assert ( + sum( + state_df_after_combine["timestamp"] + - state_df_before_combine["timestamp"] + ) + > 0 + ) + + ds_path = ensemble._path / "PARAMETER_GROUP.nc" + tob = os.path.getctime(ds_path) + + real_state0 = ensemble._realization_states.get_single_realization_state(0) + assert not real_state0.has_parameter_key_or_group("PARAMETER_GROUP") + + for i in range(1, 25): + real_state = ensemble._realization_states.get_single_realization_state(i) + + assert real_state.has_parameter_key_or_group("PARAMETER_GROUP") + assert real_state.get_parameter("PARAMETER_GROUP").timestamp == tob + + +def test_that_realization_states_shows_some_params_present(tmp_path): + with open_storage(tmp_path, mode="w") as storage: + experiment = storage.create_experiment( + parameters=[ + GenKwConfig( + name="PARAMETER_GROUP", + forward_init=False, + template_file="", + transform_function_definitions=[ + TransformFunctionDefinition("KEY1", "UNIFORM", [0, 1]), + TransformFunctionDefinition("KEY2", "UNIFORM", [0, 1]), + TransformFunctionDefinition("KEY3", "UNIFORM", [0, 1]), + ], + output_file="kw.txt", + 
update=True, + ), + GenKwConfig( + name="PARAMETER_GROUP2", + forward_init=False, + template_file="", + transform_function_definitions=[ + TransformFunctionDefinition("KEY11", "UNIFORM", [0, 1]), + TransformFunctionDefinition("KEY21", "UNIFORM", [0, 1]), + TransformFunctionDefinition("KEY31", "UNIFORM", [0, 1]), + ], + output_file="kw.txt", + update=True, + ), + ] + ) + ensemble = storage.create_ensemble(experiment, name="foo", ensemble_size=25) + for i in range(1, 25): + if i % 2 == 0: + ensemble.save_parameters( + "PARAMETER_GROUP", + i, + pd.DataFrame( + data={ + "names": ["KEY1", "KEY2", "KEY3"], + "values": [1, 2, 3], + "transformed_values": [2, 4, 6], + } + ) + .set_index(["names"]) + .to_xarray(), + ) + + if i % 3 == 0: + ensemble.save_parameters( + "PARAMETER_GROUP2", + i, + pd.DataFrame( + data={ + "names": ["KEY1", "KEY2", "KEY3"], + "values": [1, 2, 3], + "transformed_values": [2, 4, 6], + } + ) + .set_index(["names"]) + .to_xarray(), + ) + ensemble._refresh_realization_states() + state_df_before_combine = ensemble._realization_states.to_dataframe() + + for i in range(1, 25): + ds_path_1 = ensemble._realization_dir(i) / "PARAMETER_GROUP.nc" + ds_path_2 = ensemble._realization_dir(i) / "PARAMETER_GROUP2.nc" + + tob_1 = os.path.getctime(ds_path_1) if os.path.exists(ds_path_1) else -1 + tob_2 = os.path.getctime(ds_path_2) if os.path.exists(ds_path_2) else -1 + + real_state = ensemble._realization_states.get_single_realization_state(i) + if i % 6 == 0: + assert real_state.has_parameter_key_or_group("PARAMETER_GROUP") + assert real_state.has_parameter_key_or_group("PARAMETER_GROUP2") + assert real_state.get_parameter("PARAMETER_GROUP").timestamp == tob_1 + assert real_state.get_parameter("PARAMETER_GROUP2").timestamp == tob_2 + elif i % 2 == 0: + assert real_state.has_parameter_key_or_group("PARAMETER_GROUP") + assert not real_state.has_parameter_key_or_group("PARAMETER_GROUP2") + assert real_state.get_parameter("PARAMETER_GROUP").timestamp == tob_1 + elif i % 3 == 0: + assert not real_state.has_parameter_key_or_group("PARAMETER_GROUP") + assert real_state.has_parameter_key_or_group("PARAMETER_GROUP2") + assert real_state.get_parameter("PARAMETER_GROUP2").timestamp == tob_2 + else: + assert not real_state.has_parameter_key_or_group("PARAMETER_GROUP2") + assert not real_state.has_parameter_key_or_group("PARAMETER_GROUP") + + ensemble.unify_parameters() + + states = ensemble._realization_states + state_df_after_combine = ensemble._realization_states.to_dataframe() + + assert state_df_before_combine["value"].equals(state_df_after_combine["value"]) + assert ( + sum( + state_df_after_combine["timestamp"] + - state_df_before_combine["timestamp"] + ) + > 0 + ) + + tob_1 = os.path.getctime(ensemble._path / "PARAMETER_GROUP.nc") + tob_2 = os.path.getctime(ensemble._path / "PARAMETER_GROUP2.nc") + for i in range(1, 25): + real_state = states.get_single_realization_state(i) + + if i % 6 == 0: + assert real_state.get_parameter("PARAMETER_GROUP").timestamp == tob_1 + assert real_state.get_parameter("PARAMETER_GROUP2").timestamp == tob_2 + elif i % 2 == 0: + assert real_state.get_parameter("PARAMETER_GROUP").timestamp == tob_1 + + assert not real_state.has_parameter_key_or_group("PARAMETER_GROUP2") + elif i % 3 == 0: + assert not real_state.has_parameter_key_or_group("PARAMETER_GROUP") + assert real_state.get_parameter("PARAMETER_GROUP2").timestamp == tob_2 + else: + assert not real_state.has_parameter_key_or_group("PARAMETER_GROUP2") + assert not 
real_state.has_parameter_key_or_group("PARAMETER_GROUP") + + +def test_that_realization_states_update_after_rewrite_realization(tmp_path): + with open_storage(tmp_path, mode="w") as storage: + experiment = storage.create_experiment( + parameters=[ + GenKwConfig( + name="PARAMETER_GROUP", + forward_init=False, + template_file="", + transform_function_definitions=[ + TransformFunctionDefinition("KEY1", "UNIFORM", [0, 1]), + TransformFunctionDefinition("KEY2", "UNIFORM", [0, 1]), + TransformFunctionDefinition("KEY3", "UNIFORM", [0, 1]), + ], + output_file="kw.txt", + update=True, + ) + ] + ) + ensemble = storage.create_ensemble(experiment, name="foo", ensemble_size=25) + for i in range(1, 25): + ensemble.save_parameters( + "PARAMETER_GROUP", + i, + pd.DataFrame( + data={ + "names": ["KEY1", "KEY2", "KEY3"], + "values": [1, 2, 3], + "transformed_values": [2, 4, 6], + } + ) + .set_index(["names"]) + .to_xarray(), + ) + + ensemble._refresh_realization_states() + state_df_before_remove_param = ensemble._realization_states.to_dataframe() + os.rename( + ensemble._realization_dir(1) / "PARAMETER_GROUP.nc", + ensemble._realization_dir(1) / "PARAMETER_GROUP_TMP.nc", + ) + ensemble._refresh_realization_states() + state_df_after_remove_param = ensemble._realization_states.to_dataframe() + assert not state_df_after_remove_param.equals(state_df_before_remove_param) + assert state_df_after_remove_param.drop(1).equals( + state_df_before_remove_param.drop(1) + ) + + real_state1 = ensemble._realization_states.get_single_realization_state(1) + assert not real_state1.has_parameter_key_or_group("PARAMETER_GROUP") + + os.rename( + ensemble._realization_dir(1) / "PARAMETER_GROUP_TMP.nc", + ensemble._realization_dir(1) / "PARAMETER_GROUP.nc", + ) + ensemble._refresh_realization_states() + real_state1 = ensemble._realization_states.get_single_realization_state(1) + assert real_state1.has_parameter_key_or_group("PARAMETER_GROUP") + assert real_state1.get_parameter( + "PARAMETER_GROUP" + ).timestamp == os.path.getctime( + ensemble._realization_dir(1) / "PARAMETER_GROUP.nc" + ) + + ensemble.unify_parameters() + real_state1 = ensemble._realization_states.get_single_realization_state(1) + assert real_state1.has_parameter_key_or_group("PARAMETER_GROUP") + assert real_state1.get_parameter( + "PARAMETER_GROUP" + ).timestamp == os.path.getctime(ensemble._path / "PARAMETER_GROUP.nc") + + +def test_that_realization_states_shows_all_responses_present(tmp_path): + with open_storage(tmp_path, mode="w") as storage: + experiment = storage.create_experiment( + responses=[ + GenDataConfig(name="WOPR_OP1"), + GenDataConfig(name="WOPR_OP2"), + SummaryConfig( + name="summary", input_file=None, keys=["one", "two", "three"] + ), + ], + ) + ensemble = storage.create_ensemble(experiment, name="foo", ensemble_size=25) + for i in range(1, 25): + if i % 2 == 0: + ensemble.save_response( + "summary", + make_summary_data( + ["one", "two", "three"], ["2011-01-01", "2011-02-01"] + ), + i, + ) + + if i % 3 == 0: + ensemble.save_response("WOPR_OP1", make_gen_data(20), i) + + if i % 5 == 0: + ensemble.save_response("WOPR_OP2", make_gen_data(20), i) + + ensemble._refresh_realization_states() + states = ensemble._realization_states + for i in range(1, 25): + real_state = states.get_single_realization_state(i) + + assert real_state.has_response_key_or_group("one") == (i % 2 == 0) + assert real_state.has_response_key_or_group("two") == (i % 2 == 0) + assert real_state.has_response_key_or_group("three") == (i % 2 == 0) + + if i % 2 == 0: + assert 
real_state.get_response("one").timestamp == os.path.getctime( + ensemble._realization_dir(i) / "summary.nc" + ) + assert real_state.get_response("two").timestamp == os.path.getctime( + ensemble._realization_dir(i) / "summary.nc" + ) + assert real_state.get_response("three").timestamp == os.path.getctime( + ensemble._realization_dir(i) / "summary.nc" + ) + + assert real_state.has_response_key_or_group("WOPR_OP1") == (i % 3 == 0) + if i % 3 == 0: + assert real_state.get_response( + "WOPR_OP1" + ).timestamp == os.path.getctime( + ensemble._realization_dir(i) / "WOPR_OP1.nc" + ) + + if i % 5 == 0: + assert real_state.has_response_key_or_group("WOPR_OP2") + assert real_state.get_response( + "WOPR_OP2" + ).timestamp == os.path.getctime( + ensemble._realization_dir(i) / "WOPR_OP2.nc" + ) + else: + assert not real_state.has_response_key_or_group("WOPR_OP2") + + ensemble.unify_responses() + states = ensemble._realization_states + smry_tobs = os.path.getctime(ensemble._path / "summary.nc") + gen_data_tobs = os.path.getctime(ensemble._path / "gen_data.nc") + + for i in range(1, 25): + real_state = states.get_single_realization_state(i) + + assert real_state.has_response_key_or_group("one") == (i % 2 == 0) + assert real_state.has_response_key_or_group("two") == (i % 2 == 0) + assert real_state.has_response_key_or_group("three") == (i % 2 == 0) + assert real_state.has_response_key_or_group("WOPR_OP1") == (i % 3 == 0) + assert real_state.has_response_key_or_group("WOPR_OP2") == (i % 5 == 0) + + if real_state.has_response_key_or_group("summmary"): + assert real_state.get_response("one").timestamp == smry_tobs + assert real_state.get_response("two").timestamp == smry_tobs + assert real_state.get_response("three").timestamp == smry_tobs + assert real_state.get_response("summary").timestamp == smry_tobs + + if real_state.has_response_key_or_group("WOPR_OP1"): + assert real_state.get_response("WOPR_OP1").timestamp == gen_data_tobs + assert real_state.get_response("gen_data").timestamp == gen_data_tobs + + if real_state.has_response_key_or_group("WOPR_OP2"): + assert real_state.get_response("WOPR_OP2").timestamp == gen_data_tobs + assert real_state.get_response("gen_data").timestamp == gen_data_tobs + + +def test_single_realization_state_transfer_clear_responses(): + state_old = _SingleRealizationStateDict() + state_old._items_by_kind = { + "summary": { + "A": _SingleRealizationStateDictEntry(value=True, timestamp=1), + "B": _SingleRealizationStateDictEntry(value=False, timestamp=-1), + "C": _SingleRealizationStateDictEntry(value=True, timestamp=1), + } + } + + assert state_old.has_response_key_or_group("A") + assert state_old.get_response("A").timestamp == 1 + + assert not state_old.has_response_key_or_group("B") + assert state_old.get_response("B").timestamp == -1 + + assert state_old.has_response_key_or_group("C") + assert state_old.get_response("C").timestamp == 1 + + state_emptying = _SingleRealizationStateDict() + state_emptying._items_by_kind = { + "summary": { + "A": _SingleRealizationStateDictEntry(value=False, timestamp=8888), + "B": _SingleRealizationStateDictEntry(value=True, timestamp=-2), + "C": _SingleRealizationStateDictEntry(value=False, timestamp=9999), + } + } + + assert not state_emptying.has_response_key_or_group("A") + assert state_emptying.get_response("A").timestamp == 8888 + + assert state_emptying.has_response_key_or_group("B") + assert state_emptying.get_response("B").timestamp == -2 + + assert not state_emptying.has_response_key_or_group("C") + assert 
state_emptying.get_response("C").timestamp == 9999 + + state_with_no_responses = state_old.copy().assign_state(state_emptying) + + assert not state_with_no_responses.has_response_key_or_group("A") + assert state_with_no_responses.get_response("A").timestamp == 8888 + + # We expect the -1 timestamp to be kept as it is greater than the + # -2 timestamp + assert not state_with_no_responses.has_response_key_or_group("B") + assert state_with_no_responses.get_response("B").timestamp == -1 + + assert not state_with_no_responses.has_response_key_or_group("C") + assert state_with_no_responses.get_response("C").timestamp == 9999 + + +def test_single_realization_state_transfer_clear_all_but_one_response(): + state_old = _SingleRealizationStateDict() + state_old._items_by_kind = { + "summary": { + "A": _SingleRealizationStateDictEntry(value=True, timestamp=999), + "B": _SingleRealizationStateDictEntry(value=False, timestamp=-1), + "C": _SingleRealizationStateDictEntry(value=True, timestamp=888), + } + } + + assert state_old.has_response_key_or_group("A") + assert state_old.get_response("A").timestamp == 999 + + assert not state_old.has_response_key_or_group("B") + assert state_old.get_response("B").timestamp == -1 + + assert state_old.has_response_key_or_group("C") + assert state_old.get_response("C").timestamp == 888 + + state_emptying = _SingleRealizationStateDict() + state_emptying._items_by_kind = { + "summary": { + "A": _SingleRealizationStateDictEntry(value=False, timestamp=1010), + "B": _SingleRealizationStateDictEntry(value=False, timestamp=-1), + "C": _SingleRealizationStateDictEntry(value=True, timestamp=9999), + } + } + + assert not state_emptying.has_response_key_or_group("A") + assert state_emptying.get_response("A").timestamp == 1010 + + assert not state_emptying.has_response_key_or_group("B") + assert state_emptying.get_response("B").timestamp == -1 + + assert state_emptying.has_response_key_or_group("C") + assert state_emptying.get_response("C").timestamp == 9999 + + states_with_no_responses = state_old.copy().assign_state(state_emptying) + + assert not states_with_no_responses.has_response_key_or_group("A") + assert states_with_no_responses.get_response("A").timestamp == 1010 + + assert not states_with_no_responses.has_response_key_or_group("B") + assert states_with_no_responses.get_response("B").timestamp == -1 + + assert states_with_no_responses.has_response_key_or_group("C") + assert states_with_no_responses.get_response("C").timestamp == 9999 + + +def test_single_realization_state_transfer_from_state_without_any_responses(): + state_old = _SingleRealizationStateDict() + state_old._items_by_kind = { + "summary": { + "A": _SingleRealizationStateDictEntry(value=True, timestamp=999), + "B": _SingleRealizationStateDictEntry(value=False, timestamp=-1), + "C": _SingleRealizationStateDictEntry(value=True, timestamp=888), + } + } + + assert state_old.has_response_key_or_group("A") + assert state_old.get_response("A").timestamp == 999 + + assert not state_old.has_response_key_or_group("B") + assert state_old.get_response("B").timestamp == -1 + + assert state_old.has_response_key_or_group("C") + assert state_old.get_response("C").timestamp == 888 + + state_emptying = _SingleRealizationStateDict() + state_emptying._items_by_kind = { + "summary": { + "A": _SingleRealizationStateDictEntry(value=False, timestamp=1010), + "B": _SingleRealizationStateDictEntry(value=False, timestamp=-1), + "C": _SingleRealizationStateDictEntry(value=False, timestamp=9999), + } + } + + assert not 
state_emptying.has_response_key_or_group("A") + assert state_emptying.get_response("A").timestamp == 1010 + + assert not state_emptying.has_response_key_or_group("B") + assert state_emptying.get_response("B").timestamp == -1 + + assert not state_emptying.has_response_key_or_group("C") + assert state_emptying.get_response("C").timestamp == 9999 + + states_with_no_responses = state_old.copy().assign_state(state_emptying) + + assert not states_with_no_responses.has_response_key_or_group("A") + assert states_with_no_responses.get_response("A").timestamp == 1010 + + assert not states_with_no_responses.has_response_key_or_group("B") + assert states_with_no_responses.get_response("B").timestamp == -1 + + assert not states_with_no_responses.has_response_key_or_group("C") + assert states_with_no_responses.get_response("C").timestamp == 9999 + + +def test_single_realization_state_transfer_with_new_responses(): + state_old = _SingleRealizationStateDict() + state_old._items_by_kind = { + "summary": { + "A": _SingleRealizationStateDictEntry(value=True, timestamp=999), + "B": _SingleRealizationStateDictEntry(value=False, timestamp=-1), + "C": _SingleRealizationStateDictEntry(value=True, timestamp=888), + } + } + + assert state_old.has_response_key_or_group("A") + assert state_old.get_response("A").timestamp == 999 + + assert not state_old.has_response_key_or_group("B") + assert state_old.get_response("B").timestamp == -1 + + assert state_old.has_response_key_or_group("C") + assert state_old.get_response("C").timestamp == 888 + + state_with_more = _SingleRealizationStateDict() + state_with_more._items_by_kind = { + "summary": { + "AA": _SingleRealizationStateDictEntry(value=True, timestamp=1010), + "BB": _SingleRealizationStateDictEntry(value=True, timestamp=-1), + "CC": _SingleRealizationStateDictEntry(value=True, timestamp=9999), + } + } + + assert state_with_more.has_response_key_or_group("AA") + assert state_with_more.get_response("AA").timestamp == 1010 + + assert state_with_more.has_response_key_or_group("BB") + assert state_with_more.get_response("BB").timestamp == -1 + + assert state_with_more.has_response_key_or_group("CC") + assert state_with_more.get_response("CC").timestamp == 9999 + + states_with_all = state_old.copy().assign_state(state_with_more) + + assert states_with_all.has_response_key_or_group("A") + assert states_with_all.get_response("A").timestamp == 999 + + assert not states_with_all.has_response_key_or_group("B") + assert states_with_all.get_response("B").timestamp == -1 + + assert states_with_all.has_response_key_or_group("C") + assert states_with_all.get_response("C").timestamp == 888 + + assert states_with_all.has_response_key_or_group("AA") + assert states_with_all.get_response("AA").timestamp == 1010 + + assert states_with_all.has_response_key_or_group("BB") + assert states_with_all.get_response("BB").timestamp == -1 + + assert states_with_all.has_response_key_or_group("CC") + assert states_with_all.get_response("CC").timestamp == 9999 diff --git a/tests/unit_tests/test_load_forward_model.py b/tests/unit_tests/test_load_forward_model.py index 02217bed883..fa29525bf3d 100644 --- a/tests/unit_tests/test_load_forward_model.py +++ b/tests/unit_tests/test_load_forward_model.py @@ -291,5 +291,6 @@ def test_that_the_states_are_set_correctly(): experiment=ensemble.experiment, ensemble_size=ensemble_size ) facade.load_from_forward_model(new_ensemble, realizations, 0) + new_ensemble._refresh_realization_states() assert not new_ensemble.is_initalized() assert new_ensemble.has_data()