Skip to content

Commit

Permalink
Make tests pass
Browse files Browse the repository at this point in the history
  • Loading branch information
Yngve S. Kristiansen committed Sep 16, 2024
1 parent 9842058 commit af82675
Show file tree
Hide file tree
Showing 10 changed files with 1,137 additions and 1,162 deletions.
17 changes: 10 additions & 7 deletions tests/integration_tests/analysis/test_es_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from textwrap import dedent

import numpy as np
import polars
import pytest
import xarray as xr
from scipy.ndimage import gaussian_filter
Expand Down Expand Up @@ -38,14 +39,16 @@ def uniform_parameter():


@pytest.fixture
def obs():
return xr.Dataset(
def obs() -> polars.DataFrame:
return polars.DataFrame(
{
"observations": (["report_step", "index"], [[1.0, 1.0, 1.0]]),
"std": (["report_step", "index"], [[0.1, 1.0, 10.0]]),
},
coords={"index": [0, 1, 2], "report_step": [0]},
attrs={"response": "RESPONSE"},
"response_key": "RESPONSE",
"observation_key": "OBSERVATION",
"report_step": polars.Series(np.full(3, 0), dtype=polars.UInt16),
"index": polars.Series([0, 1, 2], dtype=polars.UInt16),
"observations": [1.0, 1.0, 1.0],
"std": [0.1, 1.0, 10.0],
}
)


Expand Down

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

94 changes: 41 additions & 53 deletions tests/unit_tests/analysis/test_es_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from unittest.mock import patch

import numpy as np
import polars
import pytest
import scipy as sp
import xarray as xr
Expand Down Expand Up @@ -48,14 +49,16 @@ def uniform_parameter():


@pytest.fixture
def obs():
return xr.Dataset(
def obs() -> polars.DataFrame:
return polars.DataFrame(
{
"observations": (["report_step", "index"], [[1.0, 1.0, 1.0]]),
"std": (["report_step", "index"], [[0.1, 1.0, 10.0]]),
},
coords={"index": [0, 1, 2], "report_step": [0]},
attrs={"response": "RESPONSE"},
"response_key": "RESPONSE",
"observation_key": "OBSERVATION",
"report_step": polars.Series(np.full(3, 0), dtype=polars.UInt16),
"index": polars.Series([0, 1, 2], dtype=polars.UInt16),
"observations": [1.0, 1.0, 1.0],
"std": [0.1, 1.0, 10.0],
}
)


Expand Down Expand Up @@ -98,7 +101,7 @@ def test_update_report(
smoother_update(
prior_ens,
posterior_ens,
list(experiment.observation_keys),
experiment.observation_keys,
ert_config.ensemble_config.parameters,
UpdateSettings(auto_scale_observations=misfit_preprocess),
ESSettings(inversion="subspace"),
Expand Down Expand Up @@ -134,7 +137,7 @@ def test_update_report_with_exception_in_analysis_ES(
smoother_update(
prior_ens,
posterior_ens,
list(ert_config.observations.keys()),
experiment.observation_keys,
ert_config.ensemble_config.parameters,
UpdateSettings(alpha=0.0000000001),
ESSettings(inversion="subspace"),
Expand Down Expand Up @@ -177,7 +180,7 @@ def test_update_report_with_different_observation_status_from_smoother_update(
ss = smoother_update(
prior_ens,
posterior_ens,
list(ert_config.observations.keys()),
experiment.observation_keys,
ert_config.ensemble_config.parameters,
update_settings,
ESSettings(inversion="subspace"),
Expand Down Expand Up @@ -293,7 +296,7 @@ def test_update_snapshot(
prior_storage=prior_ens,
posterior_storage=posterior_ens,
sies_smoother=sies_smoother,
observations=list(ert_config.observations.keys()),
observations=experiment.observation_keys,
parameters=list(ert_config.ensemble_config.parameters),
update_settings=UpdateSettings(),
analysis_config=IESSettings(inversion="subspace_exact"),
Expand All @@ -305,7 +308,7 @@ def test_update_snapshot(
smoother_update(
prior_ens,
posterior_ens,
list(ert_config.observations.keys()),
experiment.observation_keys,
list(ert_config.ensemble_config.parameters),
UpdateSettings(),
ESSettings(inversion="subspace"),
Expand All @@ -324,7 +327,7 @@ def test_update_snapshot(
assert sim_gen_kw != target_gen_kw

# Check that posterior is as expected
assert target_gen_kw == pytest.approx(expected_gen_kw)
assert target_gen_kw == pytest.approx(expected_gen_kw, rel=1e-5)


@pytest.mark.usefixtures("use_tmpdir")
Expand Down Expand Up @@ -380,7 +383,7 @@ def test_smoother_snapshot_alpha(
experiment = storage.create_experiment(
parameters=[uniform_parameter],
responses=[resp],
observations={"OBSERVATION": obs},
observations={"gen_data": obs},
)
prior_storage = storage.create_ensemble(
experiment,
Expand All @@ -405,13 +408,15 @@ def test_smoother_snapshot_alpha(
data = rng.uniform(0.8, 1, 3)
prior_storage.save_response(
"gen_data",
xr.Dataset(
{"values": (["name", "report_step", "index"], [[data]])},
coords={
"name": ["RESPONSE"],
"index": range(len(data)),
"report_step": [0],
},
polars.DataFrame(
{
"response_key": "RESPONSE",
"report_step": polars.Series(
np.full(len(data), 0), dtype=polars.UInt16
),
"index": polars.Series(range(len(data)), dtype=polars.UInt16),
"values": data,
}
),
iens,
)
Expand Down Expand Up @@ -512,19 +517,15 @@ def g(X):
grid.to_file("MY_EGRID.EGRID", "egrid")

resp = GenDataConfig(keys=["RESPONSE"])
obs = xr.Dataset(
obs = polars.DataFrame(
{
"observations": (
["report_step", "index"],
observations.reshape((1, num_observations)),
),
"std": (
["report_step", "index"],
observation_noise.reshape(1, num_observations),
),
},
coords={"report_step": [0], "index": np.arange(len(observations))},
attrs={"response": "RESPONSE"},
"response_key": "RESPONSE",
"observation_key": "OBSERVATION",
"report_step": 0,
"index": np.arange(len(observations)),
"observations": observations,
"std": observation_noise,
}
)

param_group = "PARAM_FIELD"
Expand All @@ -544,7 +545,7 @@ def g(X):
experiment = storage.create_experiment(
parameters=[config],
responses=[resp],
observations={"OBSERVATION": obs},
observations={"gen_data": obs},
)

prior_ensemble = storage.create_ensemble(
Expand All @@ -570,13 +571,13 @@ def g(X):

prior_ensemble.save_response(
"gen_data",
xr.Dataset(
{"values": (["name", "report_step", "index"], [[Y[:, iens]]])},
coords={
"name": ["RESPONSE"],
polars.DataFrame(
{
"response_key": "RESPONSE",
"report_step": 0,
"index": range(len(Y[:, iens])),
"report_step": [0],
},
"values": Y[:, iens],
}
),
iens,
)
Expand Down Expand Up @@ -750,19 +751,6 @@ def test_temporary_parameter_storage_with_inactive_fields(
np.testing.assert_array_equal(ds["values"].values[0], fields[iens]["values"])


def test_that_observations_keep_sorting(snake_oil_case_storage, snake_oil_storage):
"""
The order of the observations influence the update as it affects the
perturbations, so we make sure we maintain the order throughout.
"""
ert_config = snake_oil_case_storage
experiment = snake_oil_storage.get_experiment_by_name("ensemble-experiment")
prior_ens = experiment.get_ensemble_by_name("default_0")
assert list(ert_config.observations.keys()) == list(
prior_ens.experiment.observations.keys()
)


def _mock_load_observations_and_responses(
S,
observations,
Expand Down
24 changes: 9 additions & 15 deletions tests/unit_tests/storage/test_local_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import hypothesis.strategies as st
import numpy as np
import polars
import pytest
import xarray as xr
from hypothesis import assume
Expand Down Expand Up @@ -88,25 +89,18 @@ def test_that_saving_empty_responses_fails_nicely(tmp_path):
ValueError,
match="Dataset for response group 'RESPONSE' must contain a 'values' variable",
):
ensemble.save_response(
"RESPONSE",
xr.Dataset(),
0,
)
ensemble.save_response("RESPONSE", polars.DataFrame(), 0)

# Test for dataset with 'values' but no actual data
empty_data = xr.Dataset(
empty_data = polars.DataFrame(
{
"values": (
["report_step", "index"],
np.array([], dtype=float).reshape(0, 0),
)
},
coords={
"index": np.array([], dtype=int),
"report_step": np.array([], dtype=int),
},
"response_key": [],
"report_step": [],
"index": [],
"values": [],
}
)

with pytest.raises(
ValueError,
match="Responses RESPONSE are empty. Cannot proceed with saving to storage.",
Expand Down
31 changes: 11 additions & 20 deletions tests/unit_tests/test_load_forward_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from textwrap import dedent

import numpy as np
import polars
import pytest
from resdata.summary import Summary

Expand Down Expand Up @@ -166,13 +167,9 @@ def test_load_forward_model_gen_data(setup_case):

facade = LibresFacade(config)
facade.load_from_forward_model(prior_ensemble, [True], 0)
assert list(
prior_ensemble.load_responses("RESPONSE", (0,))
.sel(report_step=0, drop=True)
.to_dataframe()
.dropna()
.values.flatten()
) == [1.0, 3.0]
df = prior_ensemble.load_responses("gen_data", (0,))
filter_cond = polars.col("report_step").eq(0), polars.col("values").is_not_nan()
assert df.filter(filter_cond)["values"].to_list() == [1.0, 3.0]


def test_single_valued_gen_data_with_active_info_is_loaded(setup_case):
Expand All @@ -192,9 +189,8 @@ def test_single_valued_gen_data_with_active_info_is_loaded(setup_case):

facade = LibresFacade(config)
facade.load_from_forward_model(prior_ensemble, [True], 0)
assert list(
prior_ensemble.load_responses("RESPONSE", (0,)).to_dataframe().values.flatten()
) == [1.0]
df = prior_ensemble.load_responses("RESPONSE", (0,))
assert df["values"].to_list() == [1.0]


def test_that_all_deactivated_values_are_loaded(setup_case):
Expand All @@ -214,10 +210,8 @@ def test_that_all_deactivated_values_are_loaded(setup_case):

facade = LibresFacade(config)
facade.load_from_forward_model(prior_ensemble, [True], 0)
response = (
prior_ensemble.load_responses("RESPONSE", (0,)).to_dataframe().values.flatten()
)
assert np.isnan(response[0])
response = prior_ensemble.load_responses("RESPONSE", (0,))
assert np.isnan(response[0]["values"].to_list())
assert len(response) == 1


Expand Down Expand Up @@ -254,12 +248,9 @@ def test_loading_gen_data_without_restart(storage, run_paths, run_args):

facade = LibresFacade.from_config_file("config.ert")
facade.load_from_forward_model(prior_ensemble, [True], 0)
assert list(
prior_ensemble.load_responses("RESPONSE", (0,))
.to_dataframe()
.dropna()
.values.flatten()
) == [1.0, 3.0]
df = prior_ensemble.load_responses("RESPONSE", (0,))
df_no_nans = df.filter(polars.col("values").is_not_nan())
assert df_no_nans["values"].to_list() == [1.0, 3.0]


@pytest.mark.usefixtures("copy_snake_oil_case_storage")
Expand Down
13 changes: 6 additions & 7 deletions tests/unit_tests/test_summary_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,12 @@ def test_load_summary_response_restart_not_zero(
facade = LibresFacade.from_config_file("config.ert")
facade.load_from_forward_model(ensemble, [True], 0)

df = ensemble.load_responses("summary", (0,)).to_dataframe()
df = df.unstack(level="name")
df.columns = [col[1] for col in df.columns.values]
df.index = df.index.rename(
{"time": "Date", "realization": "Realization"}
).reorder_levels(["Realization", "Date"])
df = ensemble.load_responses("summary", (0,))
df = df.pivot(on="response_key", values="values")
df = df[df.columns[:17]]
df = df.rename({"time": "Date", "realization": "Realization"})

snapshot.assert_match(
df.dropna().iloc[:, :15].to_csv(),
df.to_pandas().to_csv(index=False),
"summary_restart",
)

0 comments on commit af82675

Please sign in to comment.