Skip to content

Commit

Permalink
Some smaller fixups
Browse files Browse the repository at this point in the history
  • Loading branch information
Yngve S. Kristiansen committed Jun 19, 2024
1 parent 200c382 commit 820589c
Show file tree
Hide file tree
Showing 6 changed files with 141 additions and 93 deletions.
1 change: 1 addition & 0 deletions src/ert/libres_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ def load_from_forward_model(
run_context.run_args,
run_context.mask,
)
run_context.ensemble.unify_responses()
_logger.debug(
f"load_from_forward_model() time_used {(time.perf_counter() - t):.4f}s"
)
Expand Down
151 changes: 102 additions & 49 deletions src/ert/storage/local_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import xarray as xr
from pandas import DataFrame
from pydantic import BaseModel
from typing_extensions import deprecated
from typing_extensions import Self, deprecated

from ert.config.gen_kw_config import GenKwConfig
from ert.config.observations import ObservationsIndices
Expand Down Expand Up @@ -61,6 +61,14 @@ def has_all_responses(self) -> bool:
def has_all_parameters(self) -> bool:
    """Tell whether every tracked parameter group is flagged as present.

    Returns:
        ``True`` when no entry of ``has_parameter_group`` is falsy (which
        includes the case where the mapping is empty), ``False`` otherwise.
    """
    for is_present in self.has_parameter_group.values():
        # Stop at the first missing group, as ``all`` would.
        if not is_present:
            return False
    return True

def copy(self) -> _SingleRealizationState:
    """Return a new state object carrying the same values as this one.

    The ``has_response`` and ``has_parameter_group`` mappings are duplicated,
    so adding or changing keys on the copy does not affect this instance.
    The two birth-time values are passed through unchanged.
    """
    response_flags = dict(self.has_response)
    parameter_group_flags = dict(self.has_parameter_group)
    return _SingleRealizationState(
        response_data_latest_birth_time=self.response_data_latest_birth_time,
        parameter_data_latest_birth_time=self.parameter_data_latest_birth_time,
        has_response=response_flags,
        has_parameter_group=parameter_group_flags,
    )


@dataclasses.dataclass
class _RealizationStates:
Expand Down Expand Up @@ -98,15 +106,23 @@ def ensure_state_exists(self, realization: int) -> _SingleRealizationState:

return self.realizations[realization]

def copy(self) -> _RealizationStates:
    """Return a new container holding copies of all realization states.

    Each per-realization entry is duplicated through its own ``copy()``
    method; the two ensemble-level birth-time values are passed through
    unchanged.
    """
    cloned_realizations = {}
    for realization_index, realization_state in self.realizations.items():
        cloned_realizations[realization_index] = realization_state.copy()

    return _RealizationStates(
        response_data_latest_birth_time=self.response_data_latest_birth_time,
        parameter_data_latest_birth_time=self.parameter_data_latest_birth_time,
        realizations=cloned_realizations,
    )

def assign_parameter_states(
self, src: _RealizationStates, parameter_group: Optional[str] = None
) -> None:
) -> Self:
self.parameter_data_latest_birth_time = src.parameter_data_latest_birth_time

for realization_index, src_real in src.realizations.items():
my_real = self.ensure_state_exists(realization_index)
my_real.parameter_data_latest_birth_time = (
src_real.parameter_data_latest_birth_time
my_real.parameter_data_latest_birth_time = max(
my_real.parameter_data_latest_birth_time,
src_real.parameter_data_latest_birth_time,
)

if parameter_group is not None:
Expand All @@ -116,21 +132,26 @@ def assign_parameter_states(
else:
my_real.has_parameter_group = {**src_real.has_parameter_group}

return self

def assign_response_states(
self, src: _RealizationStates, response_key: Optional[str] = None
) -> None:
) -> Self:
self.response_data_latest_birth_time = src.response_data_latest_birth_time

for realization_index, src_real in src.realizations.items():
my_real = self.ensure_state_exists(realization_index)
my_real.response_data_latest_birth_time = (
src_real.response_data_latest_birth_time
my_real.response_data_latest_birth_time = max(
my_real.response_data_latest_birth_time,
src_real.response_data_latest_birth_time,
)

if response_key is not None:
my_real.has_response[response_key] = src_real.has_response[response_key]
else:
my_real.has_response = {**src_real.has_response}
my_real.has_response.update({**src_real.has_response})

return self


class _Index(BaseModel):
Expand Down Expand Up @@ -246,6 +267,9 @@ def create_realization_dir(realization: int) -> Path:
self._realization_dir = create_realization_dir
self._realization_states = _RealizationStates()

def on_experiment_initialized(self) -> None:
    """Rebuild the cached realization states.

    Delegates to ``_refresh_realization_states()`` with no response key, so
    both response and parameter states are refreshed. The storage layer
    invokes this hook on each ensemble after it has (re)loaded experiments.
    """
    self._refresh_realization_states()

@classmethod
def create(
cls,
Expand Down Expand Up @@ -492,12 +516,6 @@ def _responses_exist_for_realization(
otherwise, `False`.
"""

if key in {"summary", "gen_data"}:
raise KeyError(
"Checking if used,.. if it is is that needs to"
"be handled in self._realization_states"
)

if not key:
return all(
self._realization_states.has_response(realization, response)
Expand Down Expand Up @@ -1432,7 +1450,7 @@ def unify_responses(self, key: Optional[str] = None) -> None:
"realization",
)

self._refresh_realization_states()
self._refresh_realization_states_for_responses(response_key=key)

def unify_parameters(self, key: Optional[str] = None) -> None:
self._unify_datasets(
Expand Down Expand Up @@ -1477,20 +1495,17 @@ def get_response_state(
for e in self.experiment.response_configuration
)

def _refresh_realization_states(self) -> None:
# For every realization, we want know
# Does it have all responses?
# Which responses does it have
current_states = self._realization_states
def _refresh_realization_states_for_parameters(self) -> None:
old_states = self._realization_states
states = _RealizationStates()

for parameter_group in self.experiment.parameter_configuration:
if self.has_combined_parameter_dataset(parameter_group):
birth_time = os.path.getctime(self._path / f"{parameter_group}.nc")

if current_states.parameter_data_latest_birth_time == birth_time:
if old_states.parameter_data_latest_birth_time == birth_time:
states.assign_parameter_states(
src=current_states, parameter_group=parameter_group
src=old_states, parameter_group=parameter_group
)
else:
states.parameter_data_latest_birth_time = birth_time
Expand All @@ -1517,46 +1532,55 @@ def _refresh_realization_states(self) -> None:
)

real_state.has_parameter_group[parameter_group] = ds_path.exists()

if real_state.has_parameter_group[parameter_group]:
real_state.parameter_data_latest_birth_time = os.path.getctime(
ds_path
)

self._realization_states = old_states.copy().assign_parameter_states(states)

def _refresh_realization_states_for_responses(
self, response_key: Optional[str] = None
) -> None:
old_states = self._realization_states
states = _RealizationStates()

# If it has gen data, check all names w/ non nan values
# if it has combine
has_combined_summary = self.has_combined_response_dataset("summary")
has_combined_gendata = self.has_combined_response_dataset("gen_data")

for response_config in self.experiment.response_configuration.values():
response_configs = [
c
for c in self.experiment.response_configuration.values()
if response_key is None or c.name == response_key
]
for response_config in response_configs:
if isinstance(response_config, GenDataConfig):
response_key = response_config.name
if not has_combined_gendata:
if self.has_combined_response_dataset(response_key):
birth_time = os.path.getctime(self._path / "gen_data.nc")
for realization_index in range(self.ensemble_size):
ds_path = (
self._realization_dir(realization_index)
/ f"{response_key}.nc"
)
realization_state = states.ensure_state_exists(
realization_index
)

if current_states.response_data_latest_birth_time == birth_time:
states.assign_response_states(
current_states, response_key=response_key
)
continue
else:
for realization_index in range(self.ensemble_size):
ds_path = (
self._realization_dir(realization_index)
/ f"{response_key}.nc"
)
realization_state = states.ensure_state_exists(
realization_index
if ds_path.exists():
realization_state.response_data_latest_birth_time = max(
realization_state.response_data_latest_birth_time,
os.path.getctime(ds_path),
)
realization_state.has_response[response_key] = True
realization_state.has_response["gen_data"] = True
else:
realization_state.has_response[response_key] = False

if "gen_data" not in realization_state.has_response:
realization_state.has_response["gen_data"] = False

if ds_path.exists():
realization_state.response_data_latest_birth_time = max(
realization_state.response_data_latest_birth_time,
os.path.getctime(ds_path),
)
realization_state.has_response[response_key] = True
else:
realization_state.has_response[response_key] = False
else:
combined_ds = self._load_combined_response_dataset("gen_data")
gen_data_tob = os.path.getctime(self._path / "gen_data.nc")
Expand All @@ -1575,6 +1599,10 @@ def _refresh_realization_states(self) -> None:
gen_data_tob, real_state.response_data_latest_birth_time
)

states.response_data_latest_birth_time = max(
gen_data_tob, states.response_data_latest_birth_time
)

elif isinstance(response_config, SummaryConfig):
smry_config = self.experiment.response_configuration["summary"]
assert isinstance(smry_config, SummaryConfig)
Expand All @@ -1591,6 +1619,13 @@ def _refresh_realization_states(self) -> None:

if ds_path.exists():
realization_state.has_response["summary"] = True

if (
realization_state.response_data_latest_birth_time
== os.path.getctime(ds_path)
):
continue

realization_state.response_data_latest_birth_time = max(
realization_state.response_data_latest_birth_time,
os.path.getctime(ds_path),
Expand All @@ -1613,9 +1648,19 @@ def _refresh_realization_states(self) -> None:

for realization_index in range(self.ensemble_size):
real_state = states.ensure_state_exists(realization_index)

if real_state.response_data_latest_birth_time == summary_tob:
continue

if realization_index not in summary_ds["realization"]:
for expected_key in ["summary", *expected_smry_keys]:
real_state.has_response[expected_key] = False

continue

summary_4real = summary_ds.sel(
realization=realization_index
).dropna("name")["name"]
).dropna("name", how="all")["name"]
found_keys = summary_4real["name"].data

for expected_key in expected_smry_keys:
Expand All @@ -1629,4 +1674,12 @@ def _refresh_realization_states(self) -> None:
real_state.response_data_latest_birth_time, summary_tob
)

self._realization_states = states
states.response_data_latest_birth_time = max(
summary_tob, states.response_data_latest_birth_time
)

self._realization_states = old_states.copy().assign_response_states(states)

def _refresh_realization_states(self, response_key: Optional[str] = None) -> None:
self._refresh_realization_states_for_responses(response_key)
self._refresh_realization_states_for_parameters()
3 changes: 2 additions & 1 deletion src/ert/storage/local_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def refresh(self) -> None:
self._experiments = self._load_experiments()

for ens in self._ensembles.values():
ens._refresh_realization_states()
ens.on_experiment_initialized()

def get_experiment(self, uuid: UUID) -> LocalExperiment:
"""
Expand Down Expand Up @@ -407,6 +407,7 @@ def create_ensemble(
name=str(name),
prior_ensemble_id=prior_ensemble_id,
)

if prior_ensemble:
for realization, state in enumerate(prior_ensemble.get_ensemble_state()):
if state in [
Expand Down
1 change: 0 additions & 1 deletion tests/unit_tests/scenarios/test_summary_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ def create_responses(config_file, prior_ensemble, response_times):
facade.load_from_forward_model(
prior_ensemble, [True] * facade.get_ensemble_size(), 0
)
prior_ensemble.unify_responses()


def test_that_reading_matching_time_is_ok(ert_config, storage, prior_ensemble):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def test_migrate_summary(data, forecast, time_map, tmp_path):
]
)
ensemble = experiment.create_ensemble(name="default_0", ensemble_size=5)
ensemble._refresh_realization_states()

bf._migrate_summary(ensemble, forecast, time_map)
ensemble.unify_responses()
Expand Down
Loading

0 comments on commit 820589c

Please sign in to comment.