Add observation traversal logic to LocalExperiment
Yngve S. Kristiansen authored and yngve-sk committed Jun 14, 2024
1 parent 77fa511 commit 08876f0
Showing 5 changed files with 146 additions and 48 deletions.
46 changes: 24 additions & 22 deletions src/ert/gui/ertwidgets/storage_info_widget.py
@@ -180,16 +180,16 @@ def _currentItemChanged(
return

observation_label = selected.data(0, Qt.ItemDataRole.DisplayRole)
observations_dict = self._ensemble.experiment.observations

self._figure.clear()
ax = self._figure.add_subplot(111)
ax.set_title(observation_name)
ax.grid(True)

observation_ds = observations_dict[observation_name]

response_name = observation_ds.attrs["response"]
response_name, response_type = self._ensemble.experiment.response_info_for_obs(
observation_name
)
observation_ds = self._ensemble.experiment.get_single_obs_ds(observation_name)
response_ds = self._ensemble.load_responses(
response_name,
)
@@ -256,40 +256,42 @@ def _currentTabChanged(self, index: int) -> None:
self._figure.clear()
self._canvas.draw()

observations_dict = self._ensemble.experiment.observations
for obs_name, obs_ds in observations_dict.items():
response_name = obs_ds.attrs["response"]
if response_name == "summary":
name = obs_ds.name.data[0]
else:
name = response_name
exp = self._ensemble.experiment
for obs_name in exp.observation_keys:
response_name, response_type = exp.response_info_for_obs(obs_name)
obs_ds = exp.get_single_obs_ds(obs_name)

match_list = self._observations_tree_widget.findItems(
name, Qt.MatchFlag.MatchExactly
response_name, Qt.MatchFlag.MatchExactly
)
if len(match_list) == 0:
root = QTreeWidgetItem(self._observations_tree_widget, [name])
root = QTreeWidgetItem(
self._observations_tree_widget, [response_name]
)
else:
root = match_list[0]

if "time" in obs_ds.coords:
for t in obs_ds.time:
for t in obs_ds.dropna("time").time:
QTreeWidgetItem(
root,
[str(np.datetime_as_string(t.values, unit="D")), obs_name],
[
str(np.datetime_as_string(t.values, unit="D")),
obs_name,
],
)
elif "index" in obs_ds.coords:
for t in obs_ds.index:
for t in obs_ds.dropna("index").index:
QTreeWidgetItem(root, [str(t.data), obs_name])

self._observations_tree_widget.sortItems(0, Qt.SortOrder.AscendingOrder)

for i in range(self._observations_tree_widget.topLevelItemCount()):
if self._observations_tree_widget.topLevelItem(i).childCount() > 0:
self._observations_tree_widget.setCurrentItem(
self._observations_tree_widget.topLevelItem(i).child(0)
)
break

@Slot(Ensemble)
def setEnsemble(self, ensemble: Ensemble) -> None:
24 changes: 24 additions & 0 deletions src/ert/storage/ensure_correct_xr_coordinate_order.py
@@ -0,0 +1,24 @@
import xarray as xr


def ensure_correct_coordinate_order(ds: xr.Dataset) -> xr.Dataset:
"""
Ensures correct coordinate order of a response/param dataset.
Slightly less performant than not doing it, but ensures that the
correct coordinate order is applied when doing .to_dataframe().
It is possible to omit using this and instead pass in the correct
dim order when doing .to_dataframe(), which is always the same as
the .dims of the first data var of this dataset.
"""
# Just to make the order right when
# doing .to_dataframe()
# (it seems notoriously hard to tell xarray to just reorder
# the dimensions/coordinate labels)
data_vars = list(ds.data_vars.keys())

# We assume only data vars with the same dimensions,
# i.e., (realization, *index) for all of them.
dim_order_of_first_var = ds[data_vars[0]].dims
return ds[[*dim_order_of_first_var, *data_vars]].sortby(
dim_order_of_first_var[0] # "realization" / "realizations"
)
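
For context, a minimal sketch (not part of this commit) of how the helper is meant to behave on a toy dataset. The dataset contents, variable name "values", and coordinate values are made up for illustration; only the import path and function come from the diff above.

import numpy as np
import xarray as xr

from ert.storage.ensure_correct_xr_coordinate_order import (
    ensure_correct_coordinate_order,
)

# Toy response dataset: the data var is dimensioned (realization, time),
# with realizations deliberately out of order.
ds = xr.Dataset(
    {"values": (("realization", "time"), np.arange(6.0).reshape(3, 2))},
    coords={
        "realization": [2, 0, 1],
        "time": np.array(["2020-01-01", "2020-01-02"], dtype="datetime64[ns]"),
    },
)

ordered = ensure_correct_coordinate_order(ds)
# Rows come back sorted by the first dimension of the first data var
# ("realization" here), and the index levels of .to_dataframe() follow
# that data var's (realization, time) dimension order.
print(ordered.to_dataframe())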
28 changes: 3 additions & 25 deletions src/ert/storage/local_ensemble.py
@@ -33,6 +33,7 @@
from ert.storage.mode import BaseMode, Mode, require_write

from ..config import GenDataConfig, ResponseTypes
from .ensure_correct_xr_coordinate_order import ensure_correct_coordinate_order
from .realization_storage_state import RealizationStorageState

if TYPE_CHECKING:
@@ -1240,29 +1241,6 @@ def get_observations_and_responses(

return ObservationsAndResponsesData(sorted_long_np)

@staticmethod
def _ensure_correct_coordinate_order(ds: xr.Dataset) -> xr.Dataset:
"""
Ensures correct coordinate order or response/param dataset.
Slightly less performant than not doing it, but ensure the
correct coordinate order is applied when doing .to_dataframe().
It is possible to omit using this and instead pass in the correct
dim order when doing .to_dataframe(), which is always the same as
the .dims of the first data var of this dataset.
"""
# Just to make the order right when
# doing .to_dataframe()
# (it seems notoriously hard to tell xarray to just reorder
# the dimensions/coordinate labels)
data_vars = list(ds.data_vars.keys())

# We assume only data vars with the same dimensions,
# i.e., (realization, *index) for all of them.
dim_order_of_first_var = ds[data_vars[0]].dims
return ds[[*dim_order_of_first_var, *data_vars]].sortby(
dim_order_of_first_var[0] # "realization" / "realizations"
)

def _unify_datasets(
self,
groups: List[str],
@@ -1299,7 +1277,7 @@ def _unify_datasets(
new_combined = old_combined.merge(new_combined)
os.remove(combined_ds_path)

new_combined = self._ensure_correct_coordinate_order(new_combined)
new_combined = ensure_correct_coordinate_order(new_combined)

if not new_combined:
raise ValueError("Unified dataset somehow ended up empty")
@@ -1353,7 +1331,7 @@ def unify_responses(self, key: Optional[str] = None) -> None:
new_combined_ds = xr.concat(to_concat, dim="name").sortby(
["realization", "name"]
)
new_combined_ds = self._ensure_correct_coordinate_order(new_combined_ds)
new_combined_ds = ensure_correct_coordinate_order(new_combined_ds)

if has_existing_combined:
old_combined = xr.load_dataset(self._path / "gen_data.nc")
34 changes: 33 additions & 1 deletion src/ert/storage/local_experiment.py
@@ -5,7 +5,7 @@
from datetime import datetime
from functools import cached_property
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
from uuid import UUID

import numpy as np
@@ -25,6 +25,8 @@
from ert.config.response_config import ResponseConfig
from ert.storage.mode import BaseMode, Mode, require_write

from .ensure_correct_xr_coordinate_order import ensure_correct_coordinate_order

if TYPE_CHECKING:
from ert.config.parameter_config import ParameterConfig
from ert.run_models.run_arguments import (
@@ -317,3 +319,33 @@ def observations_for_response(self, response_name: str) -> xr.Dataset:
return ds.sel(name=response_name)

return xr.Dataset()

@cached_property
def _obs_key_to_response_name_and_type(self) -> Dict[str, Tuple[str, str]]:
obs_to_response: Dict[str, Tuple[str, str]] = {}

for response_type, obs_ds_for_response in self.observations.items():
for response_name, ds_for_response in obs_ds_for_response.groupby("name"):
for obs_name in ds_for_response.dropna("obs_name", how="all")[
"obs_name"
].values:
obs_to_response[obs_name] = (response_name, response_type)

return obs_to_response

def response_info_for_obs(self, obs_name: str) -> Tuple[str, str]:
"""
Returns a tuple containing (response_name, response_type)
"""
return self._obs_key_to_response_name_and_type[obs_name]

def get_single_obs_ds(self, obs_name: str) -> xr.Dataset:
response_name, response_type = self._obs_key_to_response_name_and_type[obs_name]

# Note: Does not dropna on index
# "time" for summary, "index", "report_step" for gen_data
return ensure_correct_coordinate_order(
self.observations[response_type]
.sel(obs_name=obs_name, drop=True)
.sel(name=response_name, drop=True)
)
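
For reference, a minimal usage sketch (not part of the diff) of the new traversal helpers. The "storage" path is hypothetical; open_storage and storage.experiments are assumed to be the usual ert storage entry points, and observation_keys is assumed to be available on the experiment, as used in the widget change above.

from ert.storage import open_storage

# Assumed: an existing storage directory containing at least one experiment
# with registered observations.
with open_storage("storage", mode="r") as storage:
    experiment = next(iter(storage.experiments))

    for obs_name in experiment.observation_keys:
        # Resolve which response (and response type) the observation is attached to ...
        response_name, response_type = experiment.response_info_for_obs(obs_name)
        # ... then fetch a per-observation dataset with the "name" and "obs_name"
        # dimensions already selected away.
        obs_ds = experiment.get_single_obs_ds(obs_name)
        print(obs_name, response_type, response_name, list(obs_ds.coords))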
62 changes: 62 additions & 0 deletions tests/unit_tests/storage/test_ensemble_data_unification.py
@@ -91,6 +91,8 @@ def _create_gen_data_config_ds_and_obs(
.to_xarray()
)

gen_data_obs.attrs["response"] = "gen_data"

return gen_data_configs, gen_data_ds, gen_data_obs


@@ -135,12 +137,72 @@ def _create_summary_config_ds_and_obs(
.set_index(["name", "obs_name", "time"])
.to_xarray()
)
summary_obs_ds.attrs["response"] = "summary"

summary_ds = summary_df.set_index(["name", "time"]).to_xarray()

return summary_config, summary_ds, summary_obs_ds


@pytest.mark.usefixtures("use_tmpdir")
@pytest.mark.parametrize(
("num_reals, num_responses, num_obs, num_indices, num_report_steps"),
[
(100, 1, 1, 1, 1),
(100, 5, 3, 2, 10),
(10, 50, 100, 10, 200),
],
)
def test_that_observation_getters_from_experiment_match_expected_data(
tmpdir, num_reals, num_responses, num_obs, num_indices, num_report_steps
):
gen_data_configs, gen_data_ds, gen_data_obs = _create_gen_data_config_ds_and_obs(
num_responses, num_obs, num_indices, num_report_steps
)

summary_config, summary_ds, summary_obs = _create_summary_config_ds_and_obs(
num_responses, num_indices * num_report_steps, num_obs
)

with open_storage(tmpdir, "w") as s:
exp = s.create_experiment(
responses=[*gen_data_configs],
observations={"gen_data": gen_data_obs, "summary": summary_obs},
)

assert exp._obs_key_to_response_name_and_type == {
**{
f"gen_obs_{i}": (
f"gen_data_{i%num_responses}",
"gen_data",
)
for i in range(num_obs)
},
**{
f"sum_obs_{i}": (f"sum_key_{i%num_responses}", "summary")
for i in range(num_obs)
},
}

for i in range(num_obs):
assert (
gen_data_obs.sel(obs_name=f"gen_obs_{i}", drop=True)
.dropna("name", how="all")
.squeeze("name", drop=True)
.to_dataframe()
.dropna()
.equals(exp.get_single_obs_ds(f"gen_obs_{i}").to_dataframe().dropna())
)
assert (
summary_obs.sel(obs_name=f"sum_obs_{i}", drop=True)
.dropna("name", how="all")
.squeeze("name", drop=True)
.to_dataframe()
.dropna()
.equals(exp.get_single_obs_ds(f"sum_obs_{i}").to_dataframe().dropna())
)


@pytest.mark.usefixtures("use_tmpdir")
@pytest.mark.parametrize(
("num_reals, num_gen_data, num_gen_obs, num_indices, num_report_steps"),
