Skip to content

Commit

Permalink
Change load_parameters function to return only Dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
kvashchuka committed Dec 20, 2023
1 parent eb8b6f2 commit 88b1d4b
Show file tree
Hide file tree
Showing 14 changed files with 95 additions and 78 deletions.
20 changes: 11 additions & 9 deletions src/ert/analysis/_es_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,15 +268,17 @@ def _create_temporary_parameter_storage(
matrix: Union[npt.NDArray[np.double], xr.DataArray]
if isinstance(config_node, GenKwConfig):
t = time.perf_counter()
matrix = source_fs.load_parameters(param_group, iens_active_index).values.T # type: ignore
matrix = source_fs.load_parameters(param_group, iens_active_index)[
"values"
].values.T
t_genkw += time.perf_counter() - t
elif isinstance(config_node, SurfaceConfig):
t = time.perf_counter()
matrix = source_fs.load_parameters(param_group, iens_active_index) # type: ignore
matrix = source_fs.load_parameters(param_group, iens_active_index)["values"]
t_surface += time.perf_counter() - t
elif isinstance(config_node, Field):
t = time.perf_counter()
matrix = source_fs.load_parameters(param_group, iens_active_index) # type: ignore
matrix = source_fs.load_parameters(param_group, iens_active_index)["values"]
t_field += time.perf_counter() - t
else:
raise NotImplementedError(f"{type(config_node)} is not supported")
Expand Down Expand Up @@ -679,9 +681,7 @@ def analysis_ES(
)
for parameter_group in not_updated_parameter_groups:
for realization in iens_active_index:
ds = source_fs.load_parameters(
parameter_group, int(realization), var=None
)
ds = source_fs.load_parameters(parameter_group, int(realization))
assert isinstance(ds, xr.Dataset)
target_fs.save_parameters(
parameter_group,
Expand Down Expand Up @@ -813,7 +813,9 @@ def analysis_IES(
updated_parameter_groups.append(param_group.name)
source: Union[EnsembleReader, EnsembleAccessor] = target_fs
try:
target_fs.load_parameters(group=param_group.name, realizations=0)
target_fs.load_parameters(group=param_group.name, realizations=0)[
"values"
]
except Exception:
source = source_fs
temp_storage = _create_temporary_parameter_storage(
Expand Down Expand Up @@ -843,8 +845,8 @@ def analysis_IES(
for parameter_group in not_updated_parameter_groups:
for realization in iens_active_index:
prior_dataset = source_fs.load_parameters(
parameter_group, int(realization), var=None
)
parameter_group, int(realization)
)["values"]
assert isinstance(prior_dataset, xr.Dataset)
target_fs.save_parameters(
parameter_group,
Expand Down
2 changes: 1 addition & 1 deletion src/ert/config/ext_param_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def write_to_runpath(
Path.mkdir(file_path.parent, exist_ok=True, parents=True)

data: MutableDataType = {}
for da in ensemble.load_parameters(self.name, real_nr):
for da in ensemble.load_parameters(self.name, real_nr)["values"]:
assert isinstance(da, xr.DataArray)
name = str(da.names.values)
try:
Expand Down
2 changes: 1 addition & 1 deletion src/ert/config/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def write_to_runpath(
def _fetch_from_ensemble(
self, real_nr: int, ensemble: EnsembleReader
) -> xr.DataArray:
da = ensemble.load_parameters(self.name, real_nr)
da = ensemble.load_parameters(self.name, real_nr)["values"]
assert isinstance(da, xr.DataArray)
return da

Expand Down
2 changes: 1 addition & 1 deletion src/ert/config/gen_kw_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def read_from_runpath(
def write_to_runpath(
self, run_path: Path, real_nr: int, ensemble: EnsembleReader
) -> Dict[str, Dict[str, float]]:
array = ensemble.load_parameters(self.name, real_nr, var="transformed_values")
array = ensemble.load_parameters(self.name, real_nr)["transformed_values"]
assert isinstance(array, xr.DataArray)
if not array.size == len(self.transfer_functions):
raise ValueError(
Expand Down
2 changes: 1 addition & 1 deletion src/ert/config/surface_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def read_from_runpath(self, run_path: Path, real_nr: int) -> xr.Dataset:
def write_to_runpath(
self, run_path: Path, real_nr: int, ensemble: EnsembleReader
) -> None:
data = ensemble.load_parameters(self.name, real_nr)
data = ensemble.load_parameters(self.name, real_nr)["values"]

surf = xtgeo.RegularSurface(
ncol=self.ncol,
Expand Down
6 changes: 3 additions & 3 deletions src/ert/libres_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,9 +298,9 @@ def load_all_gen_kw_data(
gen_kws = [config for config in gen_kws if config.name == group]
for key in gen_kws:
try:
ds = ensemble.load_parameters(
key.name, realizations, var="transformed_values"
)
ds = ensemble.load_parameters(key.name, realizations)[
"transformed_values"
]
assert isinstance(ds, xr.DataArray)
ds["names"] = np.char.add(f"{key.name}:", ds["names"].astype(np.str_))
df = ds.to_dataframe().unstack(level="names")
Expand Down
12 changes: 3 additions & 9 deletions src/ert/storage/local_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,15 +203,9 @@ def has_parameter_group(self, group: str) -> bool:
return param_group_file.exists()

def load_parameters(
self,
group: str,
realizations: Union[int, npt.NDArray[np.int_], None] = None,
*,
var: Optional[str] = "values",
) -> Union[xr.DataArray, xr.Dataset]:
if var is None:
return self._load_dataset(group, realizations)
return self._load_dataset(group, realizations)[var]
self, group: str, realizations: Union[int, npt.NDArray[np.int_], None] = None
) -> xr.Dataset:
return self._load_dataset(group, realizations)

@lru_cache # noqa: B019
def load_responses(
Expand Down
4 changes: 2 additions & 2 deletions tests/unit_tests/analysis/test_adaptive_localization.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ def run_cli_ES_with_case(poly_config):
storage_path = ErtConfig.from_file(poly_config).ens_path
with open_storage(storage_path) as storage:
prior_ensemble = storage.get_ensemble_by_name(prior_sample_name)
prior_sample = prior_ensemble.load_parameters("COEFFS")
prior_sample = prior_ensemble.load_parameters("COEFFS")["values"]
posterior_ensemble = storage.get_ensemble_by_name(posterior_sample_name)
posterior_sample = posterior_ensemble.load_parameters("COEFFS")
posterior_sample = posterior_ensemble.load_parameters("COEFFS")["values"]
return prior_sample, posterior_sample


Expand Down
33 changes: 20 additions & 13 deletions tests/unit_tests/analysis/test_es_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,10 +272,12 @@ def test_update_snapshot(
rng=rng,
)

sim_gen_kw = list(prior_ens.load_parameters("SNAKE_OIL_PARAM", 0).values.flatten())
sim_gen_kw = list(
prior_ens.load_parameters("SNAKE_OIL_PARAM", 0)["values"].values.flatten()
)

target_gen_kw = list(
posterior_ens.load_parameters("SNAKE_OIL_PARAM", 0).values.flatten()
posterior_ens.load_parameters("SNAKE_OIL_PARAM", 0)["values"].values.flatten()
)

# Check that prior is not equal to posterior after updating
Expand Down Expand Up @@ -414,10 +416,12 @@ def test_localization(
rng=np.random.default_rng(42),
)

sim_gen_kw = list(prior_ens.load_parameters("SNAKE_OIL_PARAM", 0).values.flatten())
sim_gen_kw = list(
prior_ens.load_parameters("SNAKE_OIL_PARAM", 0)["values"].values.flatten()
)

target_gen_kw = list(
posterior_ens.load_parameters("SNAKE_OIL_PARAM", 0).values.flatten()
posterior_ens.load_parameters("SNAKE_OIL_PARAM", 0)["values"].values.flatten()
)

# Test that the localized values have been updated
Expand Down Expand Up @@ -589,14 +593,16 @@ def sample_prior(nx, ny):
ens_posterior = storage.get_ensemble_by_name("es_udpate")

# Check that surfaces defined in INIT_FILES are not changed by ERT
surf_prior = ens_prior.load_parameters("TOP", list(range(ensemble_size)))
surf_prior = ens_prior.load_parameters("TOP", list(range(ensemble_size)))["values"]
for i in range(ensemble_size):
_prior_init = xtgeo.surface_from_file(
f"surface/surf_init_{i}.irap", fformat="irap_ascii", dtype=np.float32
)
np.testing.assert_array_equal(surf_prior[i], _prior_init.values.data)

surf_posterior = ens_posterior.load_parameters("TOP", list(range(ensemble_size)))
surf_posterior = ens_posterior.load_parameters("TOP", list(range(ensemble_size)))[
"values"
]

assert surf_prior.shape == surf_posterior.shape

Expand Down Expand Up @@ -643,7 +649,7 @@ def _load_parameters(source_ens, iens_active_index, param_groups):
temp_storage[param_group] = _temp_storage[param_group]
return temp_storage

sim_fs.load_parameters("SNAKE_OIL_PARAM_BPR")
sim_fs.load_parameters("SNAKE_OIL_PARAM_BPR")["values"]
param_groups = list(sim_fs.experiment.parameter_configuration.keys())
prior = _load_parameters(sim_fs, list(range(10)), param_groups)
posterior = _load_parameters(posterior_fs, list(range(10)), param_groups)
Expand Down Expand Up @@ -799,11 +805,12 @@ def g(X):
)
benchmark(smoother_update_run)

prior_da = prior.load_parameters(param_group, range(num_ensemble))
posterior_da = posterior_ens.load_parameters(param_group, range(num_ensemble))
prior_da = prior.load_parameters(param_group, range(num_ensemble))["values"]
posterior_da = posterior_ens.load_parameters(param_group, range(num_ensemble))["values"]
# Make sure some, but not all parameters were updated.
assert not np.allclose(prior_da, posterior_da)
# All parameters would be updated with a global update so this would fail.

assert np.isclose(prior_da, posterior_da).sum() > 0


Expand Down Expand Up @@ -1046,9 +1053,9 @@ def test_update_subset_parameters(storage, uniform_parameter, obs):
smoother_update(
prior, posterior_ens, "id", update_config, UpdateSettings(), ESSettings()
)
assert prior.load_parameters("EXTRA_PARAMETER", 0).equals(
posterior_ens.load_parameters("EXTRA_PARAMETER", 0)
assert prior.load_parameters("EXTRA_PARAMETER", 0)["values"].equals(
posterior_ens.load_parameters("EXTRA_PARAMETER", 0)["values"]
)
assert not prior.load_parameters("PARAMETER", 0).equals(
posterior_ens.load_parameters("PARAMETER", 0)
assert not prior.load_parameters("PARAMETER", 0)["values"].equals(
posterior_ens.load_parameters("PARAMETER", 0)["values"]
)
8 changes: 6 additions & 2 deletions tests/unit_tests/cli/test_integration_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,9 @@ def test_that_prior_is_not_overwritten_in_ensemble_experiment(
experiment_id, name="iter-0", ensemble_size=num_realizations
)
sample_prior(ensemble, prior_mask)
prior_values = storage.get_ensemble(ensemble.id).load_parameters("COEFFS")
prior_values = storage.get_ensemble(ensemble.id).load_parameters("COEFFS")[
"values"
]
storage.close()

parser = ArgumentParser(prog="test_main")
Expand All @@ -427,7 +429,9 @@ def test_that_prior_is_not_overwritten_in_ensemble_experiment(
FeatureToggling.update_from_args(parsed)
run_cli(parsed)
storage = open_storage(ert_config.ens_path, mode="w")
parameter_values = storage.get_ensemble(ensemble.id).load_parameters("COEFFS")
parameter_values = storage.get_ensemble(ensemble.id).load_parameters("COEFFS")[
"values"
]

if should_resample:
with pytest.raises(AssertionError):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def test_migrate_surface(data, storage, parameter, ens_config):

for key, var in data["/REAL_0/SURFACE"].groups.items():
expect = sorted_surface(var)
actual = ensemble.load_parameters(key, 0).values.ravel()
actual = ensemble.load_parameters(key, 0)["values"].values.ravel()
assert list(expect) == list(actual), key


Expand All @@ -84,7 +84,7 @@ def test_migrate_field(data, storage, parameter, ens_config):

for key, var in data["/REAL_0/FIELD"].groups.items():
expect = np.array(var["VALUES"]).ravel()
actual = ensemble.load_parameters(key, [0]).values.ravel()
actual = ensemble.load_parameters(key, [0])["values"].values.ravel()
assert list(expect) == list(actual), key


Expand All @@ -101,13 +101,13 @@ def test_migrate_case(data, storage, enspath):
# Compare FIELDs
for key, data in real_group["FIELD"].groups.items():
expect = np.array(data["VALUES"]).ravel()
actual = ensemble.load_parameters(key, [real_index])
actual = ensemble.load_parameters(key, [real_index])["values"]
assert list(expect) == list(actual.values.ravel()), f"FIELD {key}"

# Compare SURFACEs
for key, data in real_group["SURFACE"].groups.items():
expect = sorted_surface(data)
actual = ensemble.load_parameters(key, real_index).values.ravel()
actual = ensemble.load_parameters(key, real_index)["values"].values.ravel()
assert list(expect) == list(actual), f"SURFACE {key}"


Expand Down
34 changes: 19 additions & 15 deletions tests/unit_tests/storage/test_field_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,12 +116,12 @@ def test_load_two_parameters_forward_init(storage, tmpdir):
with pytest.raises(
KeyError, match="No dataset 'PARAM_A' in storage for realization 0"
):
_ = fs.load_parameters("PARAM_A", [0])
_ = fs.load_parameters("PARAM_A", [0])["values"]

with pytest.raises(
KeyError, match="No dataset 'PARAM_B' in storage for realization 0"
):
_ = fs.load_parameters("PARAM_B", [0])
_ = fs.load_parameters("PARAM_B", [0])["values"]

assert load_from_forward_model("config.ert", fs, 0) == 1

Expand All @@ -142,10 +142,10 @@ def test_load_two_parameters_forward_init(storage, tmpdir):
numpy.testing.assert_equal(prop_b.values.data, param_b)

# should be loaded now
loaded_a = fs.load_parameters("PARAM_A", [0])
loaded_a = fs.load_parameters("PARAM_A", [0])["values"]
assert (loaded_a.values == 22).all()

loaded_b = fs.load_parameters("PARAM_B", [0])
loaded_b = fs.load_parameters("PARAM_B", [0])["values"]
assert (loaded_b.values == 77).all()


Expand Down Expand Up @@ -176,10 +176,10 @@ def test_load_two_parameters_roff(storage, tmpdir):
assert not ensemble_config["PARAM_A"].forward_init
assert not ensemble_config["PARAM_B"].forward_init

loaded_a = fs.load_parameters("PARAM_A", [0])
loaded_a = fs.load_parameters("PARAM_A", [0])["values"]
assert (loaded_a.values == 22).all()

loaded_b = fs.load_parameters("PARAM_B", [0])
loaded_b = fs.load_parameters("PARAM_B", [0])["values"]
assert (loaded_b.values == 77).all()

prop_a = xtgeo.gridproperty_from_file(
Expand Down Expand Up @@ -232,10 +232,10 @@ def test_load_two_parameters(storage, tmpdir):
assert not ensemble_config["PARAM_A"].forward_init
assert not ensemble_config["PARAM_B"].forward_init

loaded_a = fs.load_parameters("PARAM_A", [0])
loaded_a = fs.load_parameters("PARAM_A", [0])["values"]
assert (loaded_a.values == 22).all()

loaded_b = fs.load_parameters("PARAM_B", [0])
loaded_b = fs.load_parameters("PARAM_B", [0])["values"]
assert (loaded_b.values == 77).all()

prop_a = xtgeo.gridproperty_from_file(
Expand Down Expand Up @@ -350,7 +350,7 @@ def test_transformation(storage, tmpdir):
_, fs = create_runpath(storage, "config.ert", [True, True])

# stored internally as 2.5, 1.5
loaded_a = fs.load_parameters("PARAM_A", [0, 1])
loaded_a = fs.load_parameters("PARAM_A", [0, 1])["values"]
assert np.isclose(loaded_a.values[0], 2.5).all()
assert np.isclose(loaded_a.values[1], 1.5).all()

Expand Down Expand Up @@ -419,7 +419,7 @@ def test_forward_init(storage, tmpdir, config_str, expect_forward_init):
with pytest.raises(
KeyError, match="No dataset 'MY_PARAM' in storage for realization 0"
):
fs.load_parameters("MY_PARAM", [0])
fs.load_parameters("MY_PARAM", [0])["values"]

# We try to load the parameters from the forward model, this would fail if
# forward init was not set correctly
Expand All @@ -437,7 +437,7 @@ def test_forward_init(storage, tmpdir, config_str, expect_forward_init):
numpy.testing.assert_equal(prop.values.data, expect_param)

if expect_forward_init:
arr = fs.load_parameters("MY_PARAM", [0])
arr = fs.load_parameters("MY_PARAM", [0])["values"]
assert len(arr.values.ravel()) == 16


Expand Down Expand Up @@ -547,12 +547,14 @@ def test_field_param_update(tmpdir):
prior = storage.get_ensemble_by_name("prior")
posterior = storage.get_ensemble_by_name("smoother_update")

prior_result = prior.load_parameters("MY_PARAM", list(range(5)))
prior_result = prior.load_parameters("MY_PARAM", list(range(5)))["values"]
assert len(prior_result.x) == NCOL
assert len(prior_result.y) == NROW
assert len(prior_result.z) == NLAY

posterior_result = posterior.load_parameters("MY_PARAM", list(range(5)))
posterior_result = posterior.load_parameters("MY_PARAM", list(range(5)))[
"values"
]
# Only assert on the first three rows, as there are only three parameters,
# a, b and c; the rest have no correlation to the results.
assert np.linalg.det(
Expand Down Expand Up @@ -674,8 +676,10 @@ def test_parameter_update_with_inactive_cells_xtgeo_grdecl(tmpdir):
prior = storage.get_ensemble_by_name("prior")
posterior = storage.get_ensemble_by_name("smoother_update")

prior_result = prior.load_parameters("MY_PARAM", list(range(5)))
posterior_result = posterior.load_parameters("MY_PARAM", list(range(5)))
prior_result = prior.load_parameters("MY_PARAM", list(range(5)))["values"]
posterior_result = posterior.load_parameters("MY_PARAM", list(range(5)))[
"values"
]

# check the shape of internal data used in the update
assert prior_result.shape == (5, NCOL, NROW, NLAY)
Expand Down
Loading

0 comments on commit 88b1d4b

Please sign in to comment.