Skip to content

Commit

Permalink
Move some functions around
Browse files Browse the repository at this point in the history
  • Loading branch information
frode-aarstad committed Dec 29, 2023
1 parent 07dc011 commit 6d6e86f
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 71 deletions.
5 changes: 0 additions & 5 deletions src/ert/libres_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,6 @@
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union

import numpy as np
<<<<<<< HEAD
import pandas as pd
import xarray as xr
=======
>>>>>>> More refactor
from deprecation import deprecated
from pandas import DataFrame
from resdata.grid import Grid
Expand Down
105 changes: 52 additions & 53 deletions src/ert/storage/local_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,14 @@ def state_map(self) -> List[RealizationStorageState]:
def experiment(self) -> Union[LocalExperimentReader, LocalExperimentAccessor]:
return self._storage.get_experiment(self.experiment_id)

@property
def is_initalized(self) -> bool:
return RealizationStorageState.INITIALIZED in self.state_map or self.has_data

@property
def has_data(self) -> bool:
return RealizationStorageState.HAS_DATA in self.state_map

def close(self) -> None:
self.sync()

Expand All @@ -135,14 +143,6 @@ def _load_state_map(self) -> List[RealizationStorageState]:
RealizationStorageState.UNDEFINED for _ in range(self.ensemble_size)
]

@property
def is_initalized(self) -> bool:
return RealizationStorageState.INITIALIZED in self.state_map or self.has_data

@property
def has_data(self) -> bool:
return RealizationStorageState.HAS_DATA in self.state_map

def realizations_initialized(self, realizations: List[int]) -> bool:
initialized_realizations = set(
self.realization_list(RealizationStorageState.INITIALIZED)
Expand All @@ -163,6 +163,12 @@ def _filter_active_realizations(
realizations = [realization_index]
return realizations

def realization_list(self, state: RealizationStorageState) -> List[int]:
"""
Will return list of realizations with state == the specified state.
"""
return [i for i, s in enumerate(self._state_map) if s == state]

def get_summary_keyset(self) -> List[str]:
realization_folders = list(self.mount_point.glob("realization-*"))
if not realization_folders:
Expand All @@ -175,11 +181,41 @@ def get_summary_keyset(self) -> List[str]:
keys = sorted(response["name"].values)
return keys

def realization_list(self, state: RealizationStorageState) -> List[int]:
"""
Will return list of realizations with state == the specified state.
"""
return [i for i, s in enumerate(self._state_map) if s == state]
def _get_gen_data_config(self, key: str) -> GenDataConfig:
config = self.experiment.response_configuration[key]
assert isinstance(config, GenDataConfig)
return config

def get_gen_data_keyset(self) -> List[str]:
keylist = [
k
for k, v in self.experiment.response_info.items()
if "_ert_kind" in v and v["_ert_kind"] == "GenDataConfig"
]

gen_data_list = []
for key in keylist:
gen_data_config = self._get_gen_data_config(key)
if gen_data_config.report_steps is None:
gen_data_list.append(f"{key}@0")
else:
for report_step in gen_data_config.report_steps:
gen_data_list.append(f"{key}@{report_step}")
return sorted(gen_data_list, key=lambda k: k.lower())

def get_gen_kw_keyset(self) -> List[str]:
gen_kw_list = []
for key in self.experiment.parameter_info:
gen_kw_config = self.experiment.parameter_configuration[key]
assert isinstance(gen_kw_config, GenKwConfig)

for keyword in [e.name for e in gen_kw_config.transfer_functions]:
gen_kw_list.append(f"{key}:{keyword}")

if gen_kw_config.shouldUseLogScale(keyword):
gen_kw_list.append(f"LOG10_{key}:{keyword}")

return sorted(gen_kw_list, key=lambda k: k.lower())

def _load_single_dataset(
self,
Expand All @@ -199,7 +235,7 @@ def _load_single_dataset(
def _load_dataset(
self,
group: str,
realizations: Union[int, List[int], None],
realizations: Union[int, npt.NDArray[np.int_], None],
) -> xr.Dataset:
if isinstance(realizations, int):
return self._load_single_dataset(group, realizations).isel(
Expand Down Expand Up @@ -262,28 +298,6 @@ def load_all_summary_data(
return df[summary_keys]
return df

def _get_gen_data_config(self, key: str) -> GenDataConfig:
config = self.experiment.response_configuration[key]
assert isinstance(config, GenDataConfig)
return config

def get_gen_data_keyset(self) -> List[str]:
keylist = [
k
for k, v in self.experiment.response_info.items()
if "_ert_kind" in v and v["_ert_kind"] == "GenDataConfig"
]

gen_data_list = []
for key in keylist:
gen_data_config = self._get_gen_data_config(key)
if gen_data_config.report_steps is None:
gen_data_list.append(f"{key}@0")
else:
for report_step in gen_data_config.report_steps:
gen_data_list.append(f"{key}@{report_step}")
return sorted(gen_data_list, key=lambda k: k.lower())

def load_gen_data(
self,
key: str,
Expand All @@ -304,20 +318,6 @@ def load_gen_data(
columns=realizations,
)

def get_gen_kw_keyset(self) -> List[str]:
gen_kw_list = []
for key in self.experiment.parameter_info:
gen_kw_config = self.experiment.parameter_configuration[key]
assert isinstance(gen_kw_config, GenKwConfig)

for keyword in [e.name for e in gen_kw_config.transfer_functions]:
gen_kw_list.append(f"{key}:{keyword}")

if gen_kw_config.shouldUseLogScale(keyword):
gen_kw_list.append(f"LOG10_{key}:{keyword}")

return sorted(gen_kw_list, key=lambda k: k.lower())

def load_all_gen_kw_data(
self,
group: Optional[str] = None,
Expand Down Expand Up @@ -362,9 +362,8 @@ def load_all_gen_kw_data(
gen_kws = [config for config in gen_kws if config.name == group]
for key in gen_kws:
try:
ds = self.load_parameters(
key.name, list(realizations), var="transformed_values"
)
ds = self.load_parameters(key.name, realizations)["transformed_values"]
assert isinstance(ds, xr.DataArray)
ds["names"] = np.char.add(f"{key.name}:", ds["names"].astype(np.str_))
df = ds.to_dataframe().unstack(level="names")
df.columns = df.columns.droplevel()
Expand Down
4 changes: 2 additions & 2 deletions tests/integration_tests/analysis/test_es_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,9 @@ def test_that_posterior_has_lower_variance_than_prior(
facade = LibresFacade.from_config_file("poly.ert")
with open_storage(facade.enspath) as storage:
default_fs = storage.get_ensemble_by_name("default")
df_default = facade.load_all_gen_kw_data(default_fs)
df_default = default_fs.load_all_gen_kw_data()
target_fs = storage.get_ensemble_by_name("target")
df_target = facade.load_all_gen_kw_data(target_fs)
df_target = target_fs.load_all_gen_kw_data()

# We expect that ERT's update step lowers the
# generalized variance for the parameters.
Expand Down
14 changes: 5 additions & 9 deletions tests/integration_tests/cli/test_integration_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,15 +135,12 @@ def test_es_mda(tmpdir, source_root, snapshot, try_queue_and_scheduler, monkeypa

run_cli(parsed)
FeatureToggling.reset()
facade = LibresFacade.from_config_file("poly.ert")

with open_storage("storage", "r") as storage:
data = []
for iter_nr in range(4):
data.append(
facade.load_all_gen_kw_data(
storage.get_ensemble_by_name(f"iter-{iter_nr}")
)
)
ensemble = storage.get_ensemble_by_name(f"iter-{iter_nr}")
data.append(ensemble.load_all_gen_kw_data())
result = pd.concat(
data,
keys=[f"iter-{iter}" for iter in range(len(data))],
Expand Down Expand Up @@ -306,12 +303,11 @@ def _run(target):
)
run_cli(parsed)
facade = LibresFacade.from_config_file("poly.ert")

with open_storage(facade.enspath) as storage:
iter_0_fs = storage.get_ensemble_by_name(f"{target}-0")
df_iter_0 = facade.load_all_gen_kw_data(iter_0_fs)
df_iter_0 = iter_0_fs.load_all_gen_kw_data()
iter_1_fs = storage.get_ensemble_by_name(f"{target}-1")
df_iter_1 = facade.load_all_gen_kw_data(iter_1_fs)
df_iter_1 = iter_1_fs.load_all_gen_kw_data()

result = pd.concat(
[df_iter_0, df_iter_1],
Expand Down
2 changes: 0 additions & 2 deletions tests/unit_tests/cli/test_integration_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import json
import logging
import os

from argparse import ArgumentParser
from pathlib import Path
from textwrap import dedent
Expand All @@ -13,7 +12,6 @@
import pytest
import xtgeo


from ert import ensemble_evaluator
from ert.__main__ import ert_parser
from ert.cli import (
Expand Down

0 comments on commit 6d6e86f

Please sign in to comment.