Skip to content

Commit

Permalink
(maybe) Slightly implicit approach to updating state
Browse files Browse the repository at this point in the history
  • Loading branch information
Yngve S. Kristiansen committed Jun 25, 2024
1 parent edc008a commit 2d34a27
Showing 1 changed file with 16 additions and 2 deletions.
18 changes: 16 additions & 2 deletions src/ert/storage/local_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@
from pydantic import BaseModel
from typing_extensions import deprecated

from ert.config import GenDataConfig, ResponseTypes, SummaryConfig
from ert.config.gen_kw_config import GenKwConfig
from ert.config.observations import ObservationsIndices
from ert.storage.mode import BaseMode, Mode, require_write

from ..config import GenDataConfig, ResponseTypes, SummaryConfig
from .ensure_correct_xr_coordinate_order import ensure_correct_coordinate_order
from .realization_state import _MultiRealizationStateDict
from .realization_storage_state import RealizationStorageState
Expand Down Expand Up @@ -177,14 +177,16 @@ def create_realization_dir(realization: int) -> Path:

self._realization_dir = create_realization_dir
self._realization_states = _MultiRealizationStateDict()
self.__realization_states_need_refresh = True

def on_experiment_initialized(self) -> None:
"""
Executes logic that depends on the experiment of the ensemble to exist.
For example, if some logic needs to traverse response/parameter configs,
the experiment of this ensemble must be initialized before running that logic.
"""
self._refresh_realization_states()

self.__realization_states_need_refresh = True

@classmethod
def create(
Expand Down Expand Up @@ -280,6 +282,12 @@ def parent(self) -> Optional[UUID]:
def experiment(self) -> LocalExperiment:
return self._storage.get_experiment(self.experiment_id)

def _refresh_realization_state_if_needed(self):

Check failure on line 285 in src/ert/storage/local_ensemble.py

View workflow job for this annotation

GitHub Actions / type-checking (3.12)

Function is missing a return type annotation
if self.__realization_states_need_refresh:
self._refresh_realization_states()

self.__realization_states_need_refresh = False

def get_realization_mask_without_parent_failure(self) -> npt.NDArray[np.bool_]:
"""
Mask array indicating realizations without a parent failure.
Expand Down Expand Up @@ -385,6 +393,7 @@ def _parameters_exist_for_realization(self, realization: int) -> bool:
if not self.experiment.parameter_configuration:
return True

self._refresh_realization_state_if_needed()

Check failure on line 396 in src/ert/storage/local_ensemble.py

View workflow job for this annotation

GitHub Actions / type-checking (3.12)

Call to untyped function "_refresh_realization_state_if_needed" in typed context
return all(
self._realization_states.has_parameter_group(realization, parameter)
for parameter in self.experiment.parameter_configuration
Expand Down Expand Up @@ -434,6 +443,7 @@ def _responses_exist_for_realization(
otherwise, `False`.
"""

self._refresh_realization_state_if_needed()

Check failure on line 446 in src/ert/storage/local_ensemble.py

View workflow job for this annotation

GitHub Actions / type-checking (3.12)

Call to untyped function "_refresh_realization_state_if_needed" in typed context
if not key:
return all(
self._realization_states.has_response(realization, response)
Expand All @@ -454,6 +464,8 @@ def is_initalized(self) -> List[int]:
Returns the realization numbers with parameters
"""

self._refresh_realization_state_if_needed()

return list(
i
for i in range(self.ensemble_size)
Expand Down Expand Up @@ -1061,6 +1073,7 @@ def save_parameters(
dataset = dataset.expand_dims(realizations=[realization])

dataset.to_netcdf(path, engine="scipy")
self.__realization_states_need_refresh = True

@require_write
def save_response(self, group: str, data: xr.Dataset, realization: int) -> None:
Expand Down Expand Up @@ -1095,6 +1108,7 @@ def save_response(self, group: str, data: xr.Dataset, realization: int) -> None:
Path.mkdir(output_path, parents=True, exist_ok=True)

data.to_netcdf(output_path / f"{group}.nc", engine="scipy")
self.__realization_states_need_refresh = True

def calculate_std_dev_for_parameter(self, parameter_group: str) -> xr.Dataset:
if not parameter_group in self.experiment.parameter_configuration:
Expand Down

0 comments on commit 2d34a27

Please sign in to comment.