Skip to content

Commit

Permalink
Raise if params not registered are saved
Browse files Browse the repository at this point in the history
  • Loading branch information
dafeda committed Dec 13, 2023
1 parent 374c4b7 commit dcc4ecd
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 24 deletions.
3 changes: 3 additions & 0 deletions src/ert/storage/local_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,9 @@ def save_parameters(
f"must contain a 'values' variable"
)

if group not in self.experiment.parameter_configuration:
raise ValueError(f"{group} is not registered to the experiment.")

path = self.mount_point / f"realization-{realization}" / f"{group}.nc"
path.parent.mkdir(exist_ok=True)

Expand Down
5 changes: 3 additions & 2 deletions tests/unit_tests/config/test_surface_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ def surface():


def test_runpath_roundtrip(tmp_path, storage, surface):
ensemble = storage.create_experiment().create_ensemble(name="text", ensemble_size=1)
config = SurfaceConfig(
"some_name",
forward_init=True,
Expand All @@ -42,7 +41,9 @@ def test_runpath_roundtrip(tmp_path, storage, surface):
output_file=tmp_path / "output",
base_surface_path="base_surface",
)

ensemble = storage.create_experiment(parameters=[config]).create_ensemble(
name="text", ensemble_size=1
)
surface.to_file(tmp_path / "input_0", fformat="irap_ascii")

# run_path -> storage
Expand Down
32 changes: 23 additions & 9 deletions tests/unit_tests/storage/migration/test_block_fs_snake_oil.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
import ert.storage
import ert.storage.migration._block_fs_native as bfn
import ert.storage.migration.block_fs as bf
from ert.config import ErtConfig
from ert.config import ErtConfig, GenKwConfig
from ert.storage import open_storage
from ert.storage.local_storage import local_storage_set_ert_config


Expand Down Expand Up @@ -61,17 +62,30 @@ def time_map(enspath):
return bf._load_timestamps(enspath / "default_0/files/time-map")


def test_migrate_gen_kw(data, ensemble, parameter, ens_config):
def test_migrate_gen_kw(data, parameter, ens_config, tmp_path):
group_root = "/REAL_0/GEN_KW"
bf._migrate_gen_kw(ensemble, parameter, ens_config)
with open_storage(tmp_path / "storage", mode="w") as storage:
experiment = storage.create_experiment(
parameters=[
GenKwConfig(
name="SNAKE_OIL_PARAM",
forward_init=False,
template_file="",
transfer_function_definitions=[],
output_file="kw.txt",
)
]
)
ensemble = experiment.create_ensemble(name="default_0", ensemble_size=5)
bf._migrate_gen_kw(ensemble, parameter, ens_config)

for param in ens_config.parameters:
expect_names = list(data[f"{group_root}/{param}"]["name"])
expect_array = np.array(data[f"{group_root}/{param}"]["standard_normal"])
actual = ensemble.load_parameters(param, 0)
for param in ens_config.parameters:
expect_names = list(data[f"{group_root}/{param}"]["name"])
expect_array = np.array(data[f"{group_root}/{param}"]["standard_normal"])
actual = ensemble.load_parameters(param, 0)

assert expect_names == list(actual["names"]), param
assert (expect_array == actual).all(), param
assert expect_names == list(actual["names"]), param
assert (expect_array == actual).all(), param


def test_migrate_summary(data, ensemble, forecast, time_map):
Expand Down
91 changes: 78 additions & 13 deletions tests/unit_tests/storage/test_local_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import xtgeo
from resdata.grid import GridGenerator

from ert.config.field import Field
from ert.field_utils import FieldFileFormat
from ert.storage import open_storage


Expand All @@ -13,8 +15,27 @@ def test_that_egrid_files_are_saved_and_loaded_correctly(tmp_path):
mask = grid.get_actnum()
mask_values = [True] * 3 + [False] * 16 + [True]
mask.values = mask_values
grid_file = str(tmp_path / "grid.EGRID")
grid.to_file(grid_file, fformat="egrid")
param_group = "MY_PARAM"

field_config = Field(
name=param_group,
forward_init=True,
nx=grid.nrow,
ny=grid.ncol,
nz=grid.nlay,
file_format=FieldFileFormat.GRDECL,
output_transformation=None,
input_transformation=None,
truncation_min=None,
truncation_max=None,
forward_init_file="",
output_file="",
grid_file=grid_file,
)

experiment = storage.create_experiment()
experiment = storage.create_experiment(parameters=[field_config])
ensemble = storage.create_ensemble(experiment, name="foo", ensemble_size=2)
ensemble_dir = tmp_path / "ensembles" / str(ensemble.id)
assert ensemble_dir.exists()
Expand All @@ -34,31 +55,49 @@ def test_that_egrid_files_are_saved_and_loaded_correctly(tmp_path):


def test_that_grid_files_are_saved_and_loaded_correctly(tmp_path):
with open_storage(tmp_path, mode="w") as storage:
experiment = storage.create_experiment()
ensemble = storage.create_ensemble(experiment, name="foo", ensemble_size=2)
ensemble_dir = tmp_path / "ensembles" / str(ensemble.id)
assert ensemble_dir.exists()

with open_storage(tmp_path / "storage", mode="w") as storage:
mask = [True] * 3 + [False] * 16 + [True]
grid = GridGenerator.create_rectangular((4, 5, 1), (1, 1, 1), actnum=mask)
grid.save_GRID(f"{experiment.mount_point}/grid.GRID")
grid_file = str(storage.path / "grid.GRID")
grid.save_GRID(grid_file)
param_group = "MY_PARAM"
field_config = Field(
name=param_group,
forward_init=True,
nx=grid.nx,
ny=grid.ny,
nz=grid.nz,
file_format=FieldFileFormat.GRDECL,
output_transformation=None,
input_transformation=None,
truncation_min=None,
truncation_max=None,
forward_init_file="",
output_file="",
grid_file=grid_file,
)
experiment = storage.create_experiment(parameters=[field_config])
ensemble = storage.create_ensemble(experiment, name="foo", ensemble_size=2)
ensemble_dir = tmp_path / "storage" / "ensembles" / str(ensemble.id)
assert ensemble_dir.exists()

data = np.full_like(mask, np.nan, dtype=np.float32)
np.place(data, mask, np.array([1.2, 1.1, 4.3, 3.1], dtype=np.float32))
da = xr.DataArray(
data.reshape((4, 5, 1)),
data.reshape((grid.nx, grid.ny, grid.nz)),
name="values",
dims=["x", "y", "z"], # type: ignore
)
ds = da.to_dataset()
ensemble.save_parameters("MY_PARAM", 1, ds)
ensemble.save_parameters(param_group, 1, ds)

saved_file = ensemble_dir / "realization-1" / "MY_PARAM.nc"
saved_file = ensemble_dir / "realization-1" / f"{param_group}.nc"
assert saved_file.exists()

loaded_data = ensemble.load_parameters("MY_PARAM", 1)
np.testing.assert_array_equal(loaded_data.values, data.reshape((4, 5, 1)))
loaded_data = ensemble.load_parameters(param_group, 1)
np.testing.assert_array_equal(
loaded_data.values, data.reshape((grid.nx, grid.ny, grid.nz))
)


def test_that_load_responses_throws_exception(tmp_path):
Expand All @@ -77,3 +116,29 @@ def test_that_load_parameters_throws_exception(tmp_path):

with pytest.raises(expected_exception=KeyError):
ensemble.load_parameters("I_DONT_EXIST", 1)


def test_that_only_registered_parameters_can_be_saved(tmp_path):
with open_storage(tmp_path, mode="w") as storage:
experiment = storage.create_experiment()
prior = storage.create_ensemble(
experiment,
ensemble_size=1,
iteration=0,
name="prior",
)

with pytest.raises(
ValueError, match="PARAMETER is not registered to the experiment."
):
prior.save_parameters(
"PARAMETER",
0,
xr.Dataset(
{
"values": ("names", [1.0]),
"transformed_values": ("names", [1.0]),
"names": ["KEY_1"],
}
),
)

0 comments on commit dcc4ecd

Please sign in to comment.