From 9e337dd1095dd9f3eb7ff2afe52f63399e1be81b Mon Sep 17 00:00:00 2001 From: Feda Curic Date: Mon, 17 Jun 2024 14:14:18 +0200 Subject: [PATCH] Simplify alignment check in get_observations_and_responses --- src/ert/storage/local_ensemble.py | 29 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/src/ert/storage/local_ensemble.py b/src/ert/storage/local_ensemble.py index f019f163fa0..780c1b12412 100644 --- a/src/ert/storage/local_ensemble.py +++ b/src/ert/storage/local_ensemble.py @@ -1183,22 +1183,21 @@ def get_observations_and_responses( ) for x in combined[index].coords.to_index() ] - ).reshape(-1, 1) - obs_vals_1d = combined["observations"].data.reshape(-1, 1) - std_vals_1d = combined["std"].data.reshape(-1, 1) + ) + obs_vals_1d = combined["observations"].data + std_vals_1d = combined["std"].data - num_obs_names = len(obs_vals_1d) - obs_names_1d = np.full((len(std_vals_1d), 1), obs_name) + num_obs = len(obs_vals_1d) + obs_names_1d = np.array([obs_name] * num_obs) if ( - len(key_index_1d) != num_obs_names - or len(response_vals_per_real) != num_obs_names - or len(obs_names_1d) != num_obs_names - or len(std_vals_1d) != num_obs_names + len(key_index_1d) != num_obs + or response_vals_per_real.shape[0] != num_obs + or len(std_vals_1d) != num_obs ): raise IndexError( - "Axis 0 misalignment, expected axis0 length to " - f"correspond to observation names {num_obs_names}. Got:\n" + "Axis 0 misalignment, expected axis 0 length to " + f"correspond to observation names {num_obs}. Got:\n" f"len(response_vals_per_real)={len(response_vals_per_real)}\n" f"len(obs_names_1d)={len(obs_names_1d)}\n" f"len(std_vals_1d)={len(std_vals_1d)}" @@ -1217,10 +1216,10 @@ def get_observations_and_responses( combined_np_long = np.concatenate( [ - obs_names_1d, - key_index_1d, - obs_vals_1d, - std_vals_1d, + obs_names_1d.reshape(-1, 1), + key_index_1d.reshape(-1, 1), + obs_vals_1d.reshape(-1, 1), + std_vals_1d.reshape(-1, 1), response_vals_per_real, ], axis=1,