From dcc4ecd80daf1dc811077030713a8b3cb9d7ddab Mon Sep 17 00:00:00 2001 From: Feda Curic Date: Wed, 13 Dec 2023 09:46:40 +0100 Subject: [PATCH] Raise if params not registered are saved --- src/ert/storage/local_ensemble.py | 3 + .../unit_tests/config/test_surface_config.py | 5 +- .../migration/test_block_fs_snake_oil.py | 32 +++++-- .../unit_tests/storage/test_local_ensemble.py | 91 ++++++++++++++++--- 4 files changed, 107 insertions(+), 24 deletions(-) diff --git a/src/ert/storage/local_ensemble.py b/src/ert/storage/local_ensemble.py index 0eedafa8b10..9fc53bcd2ea 100644 --- a/src/ert/storage/local_ensemble.py +++ b/src/ert/storage/local_ensemble.py @@ -337,6 +337,9 @@ def save_parameters( f"must contain a 'values' variable" ) + if group not in self.experiment.parameter_configuration: + raise ValueError(f"{group} is not registered to the experiment.") + path = self.mount_point / f"realization-{realization}" / f"{group}.nc" path.parent.mkdir(exist_ok=True) diff --git a/tests/unit_tests/config/test_surface_config.py b/tests/unit_tests/config/test_surface_config.py index f174eebbb58..8f3b8c10a66 100644 --- a/tests/unit_tests/config/test_surface_config.py +++ b/tests/unit_tests/config/test_surface_config.py @@ -26,7 +26,6 @@ def surface(): def test_runpath_roundtrip(tmp_path, storage, surface): - ensemble = storage.create_experiment().create_ensemble(name="text", ensemble_size=1) config = SurfaceConfig( "some_name", forward_init=True, @@ -42,7 +41,9 @@ def test_runpath_roundtrip(tmp_path, storage, surface): output_file=tmp_path / "output", base_surface_path="base_surface", ) - + ensemble = storage.create_experiment(parameters=[config]).create_ensemble( + name="text", ensemble_size=1 + ) surface.to_file(tmp_path / "input_0", fformat="irap_ascii") # run_path -> storage diff --git a/tests/unit_tests/storage/migration/test_block_fs_snake_oil.py b/tests/unit_tests/storage/migration/test_block_fs_snake_oil.py index db60fc0e666..84b8de8884c 100644 --- a/tests/unit_tests/storage/migration/test_block_fs_snake_oil.py +++ b/tests/unit_tests/storage/migration/test_block_fs_snake_oil.py @@ -9,7 +9,8 @@ import ert.storage import ert.storage.migration._block_fs_native as bfn import ert.storage.migration.block_fs as bf -from ert.config import ErtConfig +from ert.config import ErtConfig, GenKwConfig +from ert.storage import open_storage from ert.storage.local_storage import local_storage_set_ert_config @@ -61,17 +62,30 @@ def time_map(enspath): return bf._load_timestamps(enspath / "default_0/files/time-map") -def test_migrate_gen_kw(data, ensemble, parameter, ens_config): +def test_migrate_gen_kw(data, parameter, ens_config, tmp_path): group_root = "/REAL_0/GEN_KW" - bf._migrate_gen_kw(ensemble, parameter, ens_config) + with open_storage(tmp_path / "storage", mode="w") as storage: + experiment = storage.create_experiment( + parameters=[ + GenKwConfig( + name="SNAKE_OIL_PARAM", + forward_init=False, + template_file="", + transfer_function_definitions=[], + output_file="kw.txt", + ) + ] + ) + ensemble = experiment.create_ensemble(name="default_0", ensemble_size=5) + bf._migrate_gen_kw(ensemble, parameter, ens_config) - for param in ens_config.parameters: - expect_names = list(data[f"{group_root}/{param}"]["name"]) - expect_array = np.array(data[f"{group_root}/{param}"]["standard_normal"]) - actual = ensemble.load_parameters(param, 0) + for param in ens_config.parameters: + expect_names = list(data[f"{group_root}/{param}"]["name"]) + expect_array = np.array(data[f"{group_root}/{param}"]["standard_normal"]) + actual = ensemble.load_parameters(param, 0) - assert expect_names == list(actual["names"]), param - assert (expect_array == actual).all(), param + assert expect_names == list(actual["names"]), param + assert (expect_array == actual).all(), param def test_migrate_summary(data, ensemble, forecast, time_map): diff --git a/tests/unit_tests/storage/test_local_ensemble.py b/tests/unit_tests/storage/test_local_ensemble.py index a3878fee20f..de77eca8aba 100644 --- a/tests/unit_tests/storage/test_local_ensemble.py +++ b/tests/unit_tests/storage/test_local_ensemble.py @@ -4,6 +4,8 @@ import xtgeo from resdata.grid import GridGenerator +from ert.config.field import Field +from ert.field_utils import FieldFileFormat from ert.storage import open_storage @@ -13,8 +15,27 @@ def test_that_egrid_files_are_saved_and_loaded_correctly(tmp_path): mask = grid.get_actnum() mask_values = [True] * 3 + [False] * 16 + [True] mask.values = mask_values + grid_file = str(tmp_path / "grid.EGRID") + grid.to_file(grid_file, fformat="egrid") + param_group = "MY_PARAM" + + field_config = Field( + name=param_group, + forward_init=True, + nx=grid.nrow, + ny=grid.ncol, + nz=grid.nlay, + file_format=FieldFileFormat.GRDECL, + output_transformation=None, + input_transformation=None, + truncation_min=None, + truncation_max=None, + forward_init_file="", + output_file="", + grid_file=grid_file, + ) - experiment = storage.create_experiment() + experiment = storage.create_experiment(parameters=[field_config]) ensemble = storage.create_ensemble(experiment, name="foo", ensemble_size=2) ensemble_dir = tmp_path / "ensembles" / str(ensemble.id) assert ensemble_dir.exists() @@ -34,31 +55,49 @@ def test_that_egrid_files_are_saved_and_loaded_correctly(tmp_path): def test_that_grid_files_are_saved_and_loaded_correctly(tmp_path): - with open_storage(tmp_path, mode="w") as storage: - experiment = storage.create_experiment() - ensemble = storage.create_ensemble(experiment, name="foo", ensemble_size=2) - ensemble_dir = tmp_path / "ensembles" / str(ensemble.id) - assert ensemble_dir.exists() - + with open_storage(tmp_path / "storage", mode="w") as storage: mask = [True] * 3 + [False] * 16 + [True] grid = GridGenerator.create_rectangular((4, 5, 1), (1, 1, 1), actnum=mask) - grid.save_GRID(f"{experiment.mount_point}/grid.GRID") + grid_file = str(storage.path / "grid.GRID") + grid.save_GRID(grid_file) + param_group = "MY_PARAM" + field_config = Field( + name=param_group, + forward_init=True, + nx=grid.nx, + ny=grid.ny, + nz=grid.nz, + file_format=FieldFileFormat.GRDECL, + output_transformation=None, + input_transformation=None, + truncation_min=None, + truncation_max=None, + forward_init_file="", + output_file="", + grid_file=grid_file, + ) + experiment = storage.create_experiment(parameters=[field_config]) + ensemble = storage.create_ensemble(experiment, name="foo", ensemble_size=2) + ensemble_dir = tmp_path / "storage" / "ensembles" / str(ensemble.id) + assert ensemble_dir.exists() data = np.full_like(mask, np.nan, dtype=np.float32) np.place(data, mask, np.array([1.2, 1.1, 4.3, 3.1], dtype=np.float32)) da = xr.DataArray( - data.reshape((4, 5, 1)), + data.reshape((grid.nx, grid.ny, grid.nz)), name="values", dims=["x", "y", "z"], # type: ignore ) ds = da.to_dataset() - ensemble.save_parameters("MY_PARAM", 1, ds) + ensemble.save_parameters(param_group, 1, ds) - saved_file = ensemble_dir / "realization-1" / "MY_PARAM.nc" + saved_file = ensemble_dir / "realization-1" / f"{param_group}.nc" assert saved_file.exists() - loaded_data = ensemble.load_parameters("MY_PARAM", 1) - np.testing.assert_array_equal(loaded_data.values, data.reshape((4, 5, 1))) + loaded_data = ensemble.load_parameters(param_group, 1) + np.testing.assert_array_equal( + loaded_data.values, data.reshape((grid.nx, grid.ny, grid.nz)) + ) def test_that_load_responses_throws_exception(tmp_path): @@ -77,3 +116,29 @@ def test_that_load_parameters_throws_exception(tmp_path): with pytest.raises(expected_exception=KeyError): ensemble.load_parameters("I_DONT_EXIST", 1) + + +def test_that_only_registered_parameters_can_be_saved(tmp_path): + with open_storage(tmp_path, mode="w") as storage: + experiment = storage.create_experiment() + prior = storage.create_ensemble( + experiment, + ensemble_size=1, + iteration=0, + name="prior", + ) + + with pytest.raises( + ValueError, match="PARAMETER is not registered to the experiment." + ): + prior.save_parameters( + "PARAMETER", + 0, + xr.Dataset( + { + "values": ("names", [1.0]), + "transformed_values": ("names", [1.0]), + "names": ["KEY_1"], + } + ), + )