Skip to content

Commit

Permalink
(fixups) Refactor combined response part to be generic
Browse files Browse the repository at this point in the history
  • Loading branch information
Yngve S. Kristiansen committed Jun 21, 2024
1 parent cacc2b0 commit be88a6c
Showing 1 changed file with 61 additions and 67 deletions.
128 changes: 61 additions & 67 deletions src/ert/storage/local_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
List,
Literal,
Optional,
Set,
Tuple,
Union,
)
Expand Down Expand Up @@ -1547,25 +1548,50 @@ def _refresh_for_gendata_not_combined(key: str):
# once per combined dataset, per call to this
_cached_open_combined_datasets: Dict[str, xr.Dataset] = {}

def _refresh_for_gendata_combined(key: str):
_cached_open_combined_datasets["gen_data"] = combined_ds = (
_cached_open_combined_datasets["gen_data"]
if "gen_data" in _cached_open_combined_datasets
else self._load_combined_response_dataset("gen_data")
def _refresh_for_responses_combined(response_type, response_keys: Set[str]):
_cached_open_combined_datasets[response_type] = combined_ds = (
_cached_open_combined_datasets[response_type]
if response_type in _cached_open_combined_datasets
else self._load_combined_response_dataset(response_type)
)

gen_data_ds = combined_ds.sel(name=key).dropna("realization")
ds_time_of_birth = os.path.getctime(self._path / f"{response_type}.nc")

for realization_index in range(self.ensemble_size):
real_state = states.get_single_realization_state(realization_index)

real_state.set_response(
key,
value=realization_index in gen_data_ds["realization"],
response_type="gen_data",
source=self._path / "gen_data.nc",
)
if (
real_state.has_response_key_or_group(response_type)
and real_state.get_response(response_type).timestamp
== ds_time_of_birth
):
continue

if realization_index not in combined_ds["realization"]:
for expected_key in [response_key, *response_keys]:
real_state.set_response(
key=expected_key,
value=False,
response_type=response_type,
source=self._path / f"{response_type}.nc",
)

def _refresh_for_summary_not_combined():
continue

summary_4real = combined_ds.sel(realization=realization_index).dropna(
"name", how="all"
)["name"]
found_keys = summary_4real["name"].data

for expected_key in response_keys:
real_state.set_response(
expected_key,
value=expected_key in found_keys,
response_type=response_type,
source=self._path / f"{response_type}.nc",
)

def _refresh_for_summary_not_combined(expected_keys: Set[str]):
for realization_index in range(self.ensemble_size):
ds_path = self._realization_dir(realization_index) / "summary.nc"
realization_state = states.get_single_realization_state(
Expand All @@ -1581,80 +1607,48 @@ def _refresh_for_summary_not_combined():
continue

ds = xr.open_dataset(ds_path)
for smry_key in expected_smry_keys:
for smry_key in expected_keys:
realization_state.set_response(
key=smry_key,
value=smry_key in ds["name"],
response_type="summary",
source=ds_path,
)
else:
for smry_key in expected_smry_keys:
for smry_key in expected_keys:
realization_state.set_response(
key=smry_key,
value=False,
response_type="summary",
source=ds_path,
)

def _refresh_for_summary_combined():
_cached_open_combined_datasets["summary"] = summary_ds = (
_cached_open_combined_datasets["summary"]
if "summary" in _cached_open_combined_datasets
else self._load_combined_response_dataset("summary")
)

summary_tob = os.path.getctime(self._path / "summary.nc")

for realization_index in range(self.ensemble_size):
real_state = states.get_single_realization_state(realization_index)

if (
real_state.has_response_key_or_group("summary")
and real_state.get_response("summary").timestamp == summary_tob
):
continue

if realization_index not in summary_ds["realization"]:
for expected_key in ["summary", *expected_smry_keys]:
real_state.set_response(
key=expected_key,
value=False,
response_type="summary",
source=self._path / "summary.nc",
)

continue

summary_4real = summary_ds.sel(realization=realization_index).dropna(
"name", how="all"
)["name"]
found_keys = summary_4real["name"].data

for expected_key in expected_smry_keys:
real_state.set_response(
expected_key,
value=expected_key in found_keys,
response_type="summary",
source=self._path / "summary.nc",
)
keys_per_response_type: Dict[str, Set[str]] = {}

for response_config in response_configs:
if isinstance(response_config, GenDataConfig):
response_key = response_config.name
if not has_combined_gendata:
_refresh_for_gendata_not_combined(response_key)
else:
_refresh_for_gendata_combined(response_key)
if "gen_data" not in keys_per_response_type:
keys_per_response_type["gen_data"] = set()

keys_per_response_type["gen_data"].add(response_config.name)

elif isinstance(response_config, SummaryConfig):
assert isinstance(response_config, SummaryConfig)
expected_smry_keys = response_config.keys

if not has_combined_summary:
_refresh_for_summary_not_combined()
else:
_refresh_for_summary_combined()
if "summary" not in keys_per_response_type:
keys_per_response_type["summary"] = set()

keys_per_response_type["summary"].update(response_config.keys)

for response_type, expected_keys in keys_per_response_type.items():
if self.has_combined_response_dataset(response_type):
_refresh_for_responses_combined(response_type, expected_keys)
elif response_type == "gen_data":
# Not combined and one file per response key
for k in expected_keys:
_refresh_for_gendata_not_combined(k)
else:
# Not combined and one file with multiple response keys
_refresh_for_summary_not_combined(expected_keys)

self._realization_states = old_states.copy().assign_states(states)

Expand Down

0 comments on commit be88a6c

Please sign in to comment.