Skip to content

Commit

Permalink
Fix failing tests
Browse files Browse the repository at this point in the history
  • Loading branch information
kvashchuka committed Dec 13, 2023
1 parent cf0e559 commit 9eae710
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 19 deletions.
2 changes: 1 addition & 1 deletion src/ert/config/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def write_to_runpath(
def _fetch_from_ensemble(
self, real_nr: int, ensemble: EnsembleReader
) -> xr.DataArray:
da = ensemble.load_parameters(self.name, real_nr)
da = ensemble.load_parameters(self.name, real_nr)["values"]
assert isinstance(da, xr.DataArray)
return da

Expand Down
2 changes: 1 addition & 1 deletion src/ert/config/surface_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def read_from_runpath(self, run_path: Path, real_nr: int) -> xr.Dataset:
def write_to_runpath(
self, run_path: Path, real_nr: int, ensemble: EnsembleReader
) -> None:
data = ensemble.load_parameters(self.name, real_nr)
data = ensemble.load_parameters(self.name, real_nr)["values"]

surf = xtgeo.RegularSurface(
ncol=self.ncol,
Expand Down
12 changes: 8 additions & 4 deletions tests/unit_tests/analysis/test_es_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,10 +231,12 @@ def test_update_snapshot(
np.random.default_rng(3593114179000630026631423308983283277868),
)

sim_gen_kw = list(prior_ens.load_parameters("SNAKE_OIL_PARAM", 0).values.flatten())
sim_gen_kw = list(
prior_ens.load_parameters("SNAKE_OIL_PARAM", 0)["values"].values.flatten()
)

target_gen_kw = list(
posterior_ens.load_parameters("SNAKE_OIL_PARAM", 0).values.flatten()
posterior_ens.load_parameters("SNAKE_OIL_PARAM", 0)["values"].values.flatten()
)

assert sim_gen_kw != target_gen_kw
Expand Down Expand Up @@ -386,10 +388,12 @@ def test_localization(
rng=np.random.default_rng(3593114179000630026631423308983283277868),
)

sim_gen_kw = list(prior_ens.load_parameters("SNAKE_OIL_PARAM", 0).values.flatten())
sim_gen_kw = list(
prior_ens.load_parameters("SNAKE_OIL_PARAM", 0)["values"].values.flatten()
)

target_gen_kw = list(
posterior_ens.load_parameters("SNAKE_OIL_PARAM", 0).values.flatten()
posterior_ens.load_parameters("SNAKE_OIL_PARAM", 0)["values"].values.flatten()
)

# Test that the localized values has been updated
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def test_migrate_surface(data, storage, parameter, ens_config):

for key, var in data["/REAL_0/SURFACE"].groups.items():
expect = sorted_surface(var)
actual = ensemble.load_parameters(key, 0).values.ravel()
actual = ensemble.load_parameters(key, 0)["values"].values.ravel()
assert list(expect) == list(actual), key


Expand All @@ -84,7 +84,7 @@ def test_migrate_field(data, storage, parameter, ens_config):

for key, var in data["/REAL_0/FIELD"].groups.items():
expect = np.array(var["VALUES"]).ravel()
actual = ensemble.load_parameters(key, [0]).values.ravel()
actual = ensemble.load_parameters(key, [0])["values"].values.ravel()
assert list(expect) == list(actual), key


Expand All @@ -101,13 +101,13 @@ def test_migrate_case(data, storage, enspath):
# Compare FIELDs
for key, data in real_group["FIELD"].groups.items():
expect = np.array(data["VALUES"]).ravel()
actual = ensemble.load_parameters(key, [real_index])
actual = ensemble.load_parameters(key, [real_index])["values"]
assert list(expect) == list(actual.values.ravel()), f"FIELD {key}"

# Compare SURFACEs
for key, data in real_group["SURFACE"].groups.items():
expect = sorted_surface(data)
actual = ensemble.load_parameters(key, real_index).values.ravel()
actual = ensemble.load_parameters(key, real_index)["values"].values.ravel()
assert list(expect) == list(actual), f"SURFACE {key}"


Expand Down
17 changes: 8 additions & 9 deletions tests/unit_tests/storage/test_parameter_sample_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,6 @@ def test_that_order_of_input_in_user_input_is_abritrary_for_gen_kw_init_files(
def test_surface_param(
storage,
tmpdir,
config_str,
expect_forward_init,
expect_num_loaded,
error,
Expand Down Expand Up @@ -550,9 +549,9 @@ def test_that_sampling_is_fixed_from_name(
key_hash = sha256(b"1234" + b"KW_NAME:MY_KEYWORD")
seed = np.frombuffer(key_hash.digest(), dtype="uint32")
expected = np.random.default_rng(seed).standard_normal(num_realisations)
assert fs.load_parameters("KW_NAME").sel(
names="MY_KEYWORD"
).values.ravel().tolist() == list(expected)
assert fs.load_parameters("KW_NAME").sel(names="MY_KEYWORD")[
"values"
].values.ravel().tolist() == list(expected)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -614,7 +613,7 @@ def test_that_sub_sample_maintains_order(tmpdir, storage, mask, expected):

assert (
fs.load_parameters("KW_NAME")
.sel(names="MY_KEYWORD")
.sel(names="MY_KEYWORD")["values"]
.values.ravel()
.tolist()
== expected
Expand Down Expand Up @@ -740,7 +739,7 @@ def test_surface_param_update(tmpdir):
.T
)
posterior_param = (
posterior.load_parameters("MY_PARAM", range(5))
posterior.load_parameters("MY_PARAM", range(5))["values"]
.values.reshape(5, 2 * 3)
.T
)
Expand Down Expand Up @@ -914,6 +913,6 @@ def test_gen_kw_optional_template(storage, tmpdir, config_str, expected):
fh.writelines("MY_KEYWORD NORMAL 0 1")

create_runpath(storage, "config.ert")
assert list(storage.ensembles)[0].load_parameters(
"KW_NAME"
).values.flatten().tolist() == pytest.approx([expected])
assert list(storage.ensembles)[0].load_parameters("KW_NAME")[
"values"
].values.flatten().tolist() == pytest.approx([expected])

0 comments on commit 9eae710

Please sign in to comment.