Skip to content

Commit

Permalink
Refactor some more
Browse files Browse the repository at this point in the history
  • Loading branch information
frode-aarstad committed Dec 29, 2023
1 parent c6fdba5 commit 07dc011
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 65 deletions.
8 changes: 5 additions & 3 deletions src/ert/dark_storage/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,17 @@ def data_for_key(
"""Returns a pandas DataFrame with the datapoints for a given key for a
given case. The row index is the realization number, and the columns are an
index over the indexes/dates"""

if key.startswith("LOG10_"):
key = key[6:]
if key in ensemble.get_summary_keyset():
data = ensemble.load_all_summary_data([key], realization_index)
data = data[key].unstack(level="Date")
elif key in ensemble.get_gen_kw_keyset():
data = ensemble.gather_gen_kw_data(key, realization_index)
data = ensemble.load_all_gen_kw_data(key.split(":")[0], realization_index)
if data.empty:
return data
return pd.DataFrame()
data = data[key].to_frame().dropna()
data.columns = pd.Index([0])
elif key in ensemble.get_gen_data_keyset():
key_parts = key.split("@")
Expand All @@ -57,7 +59,7 @@ def data_for_key(
realization_index,
).T
except (ValueError, KeyError):
data = pd.DataFrame()
return pd.DataFrame()
else:
return pd.DataFrame()

Expand Down
81 changes: 21 additions & 60 deletions src/ert/storage/local_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,20 @@ def realizations_initialized(self, realizations: List[int]) -> bool:
)
return all(real in initialized_realizations for real in realizations)

def has_parameter_group(self, group: str) -> bool:
param_group_file = self.mount_point / f"realization-0/{group}.nc"
return param_group_file.exists()

def _filter_active_realizations(
self, realization_index: Optional[int] = None
) -> List[int]:
realizations = self.realization_list(RealizationStorageState.HAS_DATA)
if realization_index is not None:
if realization_index not in realizations:
raise IndexError(f"No such realization {realization_index}")
realizations = [realization_index]
return realizations

def get_summary_keyset(self) -> List[str]:
realization_folders = list(self.mount_point.glob("realization-*"))
if not realization_folders:
Expand Down Expand Up @@ -185,7 +199,7 @@ def _load_single_dataset(
def _load_dataset(
self,
group: str,
realizations: Union[int, npt.NDArray[np.int_], None],
realizations: Union[int, List[int], None],
) -> xr.Dataset:
if isinstance(realizations, int):
return self._load_single_dataset(group, realizations).isel(
Expand All @@ -201,10 +215,6 @@ def _load_dataset(
datasets = [self._load_single_dataset(group, i) for i in realizations]
return xr.combine_nested(datasets, "realizations")

def has_parameter_group(self, group: str) -> bool:
param_group_file = self.mount_point / f"realization-0/{group}.nc"
return param_group_file.exists()

def load_parameters(
self, group: str, realizations: Union[int, npt.NDArray[np.int_], None] = None
) -> xr.Dataset:
Expand All @@ -227,24 +237,17 @@ def load_responses(
assert isinstance(response, xr.Dataset)
return response

def get_active_realizations(self) -> List[int]:
return self.realization_list(RealizationStorageState.HAS_DATA)

def load_all_summary_data(
self,
keys: Optional[List[str]] = None,
realization_index: Optional[int] = None,
) -> pd.DataFrame:
realizations = self.get_active_realizations()
if realization_index is not None:
if realization_index not in realizations:
raise IndexError(f"No such realization {realization_index}")
realizations = [realization_index]

summary_keys = self.get_summary_keyset()

try:
df = self.load_responses("summary", tuple(realizations)).to_dataframe()
df = self.load_responses(
"summary", tuple(self._filter_active_realizations(realization_index))
).to_dataframe()
except (ValueError, KeyError):
return pd.DataFrame()
df = df.unstack(level="name")
Expand All @@ -259,24 +262,6 @@ def load_all_summary_data(
return df[summary_keys]
return df

def gather_summary_data(
self,
key: str,
realization_index: Optional[int] = None,
) -> Union[pd.DataFrame, pd.Series]:
data = self.load_all_summary_data([key], realization_index)
if data.empty:
return data
idx = data.index.duplicated()
if idx.any():
data = data[~idx]
logger.warning(
"The simulation data contains duplicate "
"timestamps. A possible explanation is that your "
"simulation timestep is less than a second."
)
return data.unstack(level="Realization")

def _get_gen_data_config(self, key: str) -> GenDataConfig:
config = self.experiment.response_configuration[key]
assert isinstance(config, GenDataConfig)
Expand Down Expand Up @@ -305,11 +290,7 @@ def load_gen_data(
report_step: int,
realization_index: Optional[int] = None,
) -> pd.DataFrame:
realizations = self.realization_list(RealizationStorageState.HAS_DATA)
if realization_index is not None:
if realization_index not in realizations:
raise IndexError(f"No such realization {realization_index}")
realizations = [realization_index]
realizations = self._filter_active_realizations(realization_index)
try:
vals = self.load_responses(key, tuple(realizations)).sel(
report_step=report_step, drop=True
Expand All @@ -324,14 +305,8 @@ def load_gen_data(
)

def get_gen_kw_keyset(self) -> List[str]:
gen_kw_keys = [
k
for k, v in self.experiment.parameter_info.items()
if "_ert_kind" in v and v["_ert_kind"] == "GenKwConfig"
]

gen_kw_list = []
for key in gen_kw_keys:
for key in self.experiment.parameter_info:
gen_kw_config = self.experiment.parameter_configuration[key]
assert isinstance(gen_kw_config, GenKwConfig)

Expand All @@ -343,20 +318,6 @@ def get_gen_kw_keyset(self) -> List[str]:

return sorted(gen_kw_list, key=lambda k: k.lower())

def gather_gen_kw_data(
self,
key: str,
realization_index: Optional[int],
) -> pd.DataFrame:
try:
data = self.load_all_gen_kw_data(
key.split(":")[0],
realization_index,
)
return data[key].to_frame().dropna()
except KeyError:
return pd.DataFrame()

def load_all_gen_kw_data(
self,
group: Optional[str] = None,
Expand Down Expand Up @@ -402,7 +363,7 @@ def load_all_gen_kw_data(
for key in gen_kws:
try:
ds = self.load_parameters(
key.name, realizations, var="transformed_values"
key.name, list(realizations), var="transformed_values"
)
ds["names"] = np.char.add(f"{key.name}:", ds["names"].astype(np.str_))
df = ds.to_dataframe().unstack(level="names")
Expand Down
27 changes: 27 additions & 0 deletions tests/unit_tests/dark_storage/test_http_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import uuid

import pandas as pd
import pytest
from numpy.testing import assert_array_equal
from requests import Response

Expand Down Expand Up @@ -229,3 +230,29 @@ def test_get_record_labels(poly_example_tmp_dir, dark_storage_client):
labels = resp.json()

assert labels == []


@pytest.mark.parametrize(
"coeffs",
[
"COEFFS:a",
"COEFFS:b",
"COEFFS:c",
],
)
def test_get_coeffs_records(poly_example_tmp_dir, dark_storage_client, coeffs):
resp: Response = dark_storage_client.get("/experiments")
answer_json = resp.json()
ensemble_id = answer_json[0]["ensemble_ids"][0]

resp: Response = dark_storage_client.get(
f"/ensembles/{ensemble_id}/records/{coeffs}/",
headers={"accept": "application/x-parquet"},
)

stream = io.BytesIO(resp.content)
dataframe = pd.read_parquet(stream)

assert all(dataframe.index.values == [1, 2, 4])
assert dataframe.index.name == "Realization"
assert dataframe.shape == tuple([3, 1])
4 changes: 2 additions & 2 deletions tests/unit_tests/test_libres_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def test_keyword_type_checks_missing_key(snake_oil_default_storage):
def test_data_fetching_missing_key(empty_case):
data = [
empty_case.load_all_summary_data(["nokey"]),
empty_case.gather_gen_kw_data("nokey", None),
empty_case.load_all_gen_kw_data("nokey", None),
]

for dataframe in data:
Expand Down Expand Up @@ -293,7 +293,7 @@ def test_gen_kw_collector(snake_oil_default_storage, snapshot):
snapshot.assert_match(data.round(6).to_csv(), "gen_kw_collector_3.csv")

non_existing_realization_index = 150
with pytest.raises(KeyError):
with pytest.raises(tuple([IndexError, KeyError])):
_ = snake_oil_default_storage.load_all_gen_kw_data(
"SNAKE_OIL_PARAM",
realization_index=non_existing_realization_index,
Expand Down

0 comments on commit 07dc011

Please sign in to comment.