From c313ca5201f5671437c5c95fa94fdcfb13f880af Mon Sep 17 00:00:00 2001 From: Feda Curic Date: Mon, 17 Jun 2024 14:14:18 +0200 Subject: [PATCH] Simplify alignment check in get_observations_and_responses --- src/ert/storage/local_ensemble.py | 29 ++++++++++----------- tests/unit_tests/analysis/test_es_update.py | 21 ++++++++------- 2 files changed, 25 insertions(+), 25 deletions(-) diff --git a/src/ert/storage/local_ensemble.py b/src/ert/storage/local_ensemble.py index f019f163fa0..780c1b12412 100644 --- a/src/ert/storage/local_ensemble.py +++ b/src/ert/storage/local_ensemble.py @@ -1183,22 +1183,21 @@ def get_observations_and_responses( ) for x in combined[index].coords.to_index() ] - ).reshape(-1, 1) - obs_vals_1d = combined["observations"].data.reshape(-1, 1) - std_vals_1d = combined["std"].data.reshape(-1, 1) + ) + obs_vals_1d = combined["observations"].data + std_vals_1d = combined["std"].data - num_obs_names = len(obs_vals_1d) - obs_names_1d = np.full((len(std_vals_1d), 1), obs_name) + num_obs = len(obs_vals_1d) + obs_names_1d = np.array([obs_name] * num_obs) if ( - len(key_index_1d) != num_obs_names - or len(response_vals_per_real) != num_obs_names - or len(obs_names_1d) != num_obs_names - or len(std_vals_1d) != num_obs_names + len(key_index_1d) != num_obs + or response_vals_per_real.shape[0] != num_obs + or len(std_vals_1d) != num_obs ): raise IndexError( - "Axis 0 misalignment, expected axis0 length to " - f"correspond to observation names {num_obs_names}. Got:\n" + "Axis 0 misalignment, expected axis 0 length to " + f"correspond to observation names {num_obs}. Got:\n" f"len(response_vals_per_real)={len(response_vals_per_real)}\n" f"len(obs_names_1d)={len(obs_names_1d)}\n" f"len(std_vals_1d)={len(std_vals_1d)}" @@ -1217,10 +1216,10 @@ def get_observations_and_responses( combined_np_long = np.concatenate( [ - obs_names_1d, - key_index_1d, - obs_vals_1d, - std_vals_1d, + obs_names_1d.reshape(-1, 1), + key_index_1d.reshape(-1, 1), + obs_vals_1d.reshape(-1, 1), + std_vals_1d.reshape(-1, 1), response_vals_per_real, ], axis=1, diff --git a/tests/unit_tests/analysis/test_es_update.py b/tests/unit_tests/analysis/test_es_update.py index 8f46b5b19bc..df45705c948 100644 --- a/tests/unit_tests/analysis/test_es_update.py +++ b/tests/unit_tests/analysis/test_es_update.py @@ -47,20 +47,22 @@ def uniform_parameter(): @pytest.fixture def obs(): + observations = np.array([1.0, 1.0, 1.0]) + errors = np.array([0.1, 1.0, 10.0]) return xr.Dataset( { "observations": ( - ["name", "obs_name", "report_step", "index"], - [[[[1.0, 1.0, 1.0]]]], + ["name", "obs_name", "index", "report_step"], + np.reshape(observations, (1, 1, 3, 1)), ), "std": ( - ["name", "obs_name", "report_step", "index"], - [[[[0.1, 1.0, 10.0]]]], + ["name", "obs_name", "index", "report_step"], + np.reshape(errors, (1, 1, 3, 1)), ), }, coords={ + "name": ["RESPONSE"], "obs_name": ["OBSERVATION"], - "name": ["RESPONSE"], # Has to correspond to actual response name "index": [0, 1, 2], "report_step": [0], }, @@ -517,18 +519,17 @@ def g(X): obs = xr.Dataset( { "observations": ( - ["report_step", "index"], - observations.reshape((1, num_observations)), + ["index", "report_step"], + observations.reshape((num_observations, 1)), ), "std": ( - ["report_step", "index"], - observation_noise.reshape(1, num_observations), + ["index", "report_step"], + observation_noise.reshape(num_observations, 1), ), }, coords={"report_step": [0], "index": np.arange(len(observations))}, attrs={"response": "gen_data"}, ) - obs = obs.expand_dims({"obs_name": ["OBSERVATION"]}) obs = obs.expand_dims({"name": ["RESPONSE"]})