Skip to content

Commit

Permalink
Simplify alignment check in get_observations_and_responses
Browse files Browse the repository at this point in the history
  • Loading branch information
dafeda committed Jun 17, 2024
1 parent 054d9aa commit c313ca5
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 25 deletions.
29 changes: 14 additions & 15 deletions src/ert/storage/local_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -1183,22 +1183,21 @@ def get_observations_and_responses(
)
for x in combined[index].coords.to_index()
]
).reshape(-1, 1)
obs_vals_1d = combined["observations"].data.reshape(-1, 1)
std_vals_1d = combined["std"].data.reshape(-1, 1)
)
obs_vals_1d = combined["observations"].data
std_vals_1d = combined["std"].data

num_obs_names = len(obs_vals_1d)
obs_names_1d = np.full((len(std_vals_1d), 1), obs_name)
num_obs = len(obs_vals_1d)
obs_names_1d = np.array([obs_name] * num_obs)

if (
len(key_index_1d) != num_obs_names
or len(response_vals_per_real) != num_obs_names
or len(obs_names_1d) != num_obs_names
or len(std_vals_1d) != num_obs_names
len(key_index_1d) != num_obs
or response_vals_per_real.shape[0] != num_obs
or len(std_vals_1d) != num_obs
):
raise IndexError(
"Axis 0 misalignment, expected axis0 length to "
f"correspond to observation names {num_obs_names}. Got:\n"
"Axis 0 misalignment, expected axis 0 length to "
f"correspond to observation names {num_obs}. Got:\n"
f"len(response_vals_per_real)={len(response_vals_per_real)}\n"
f"len(obs_names_1d)={len(obs_names_1d)}\n"
f"len(std_vals_1d)={len(std_vals_1d)}"
Expand All @@ -1217,10 +1216,10 @@ def get_observations_and_responses(

combined_np_long = np.concatenate(
[
obs_names_1d,
key_index_1d,
obs_vals_1d,
std_vals_1d,
obs_names_1d.reshape(-1, 1),
key_index_1d.reshape(-1, 1),
obs_vals_1d.reshape(-1, 1),
std_vals_1d.reshape(-1, 1),
response_vals_per_real,
],
axis=1,
Expand Down
21 changes: 11 additions & 10 deletions tests/unit_tests/analysis/test_es_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,20 +47,22 @@ def uniform_parameter():

@pytest.fixture
def obs():
observations = np.array([1.0, 1.0, 1.0])
errors = np.array([0.1, 1.0, 10.0])
return xr.Dataset(
{
"observations": (
["name", "obs_name", "report_step", "index"],
[[[[1.0, 1.0, 1.0]]]],
["name", "obs_name", "index", "report_step"],
np.reshape(observations, (1, 1, 3, 1)),
),
"std": (
["name", "obs_name", "report_step", "index"],
[[[[0.1, 1.0, 10.0]]]],
["name", "obs_name", "index", "report_step"],
np.reshape(errors, (1, 1, 3, 1)),
),
},
coords={
"name": ["RESPONSE"],
"obs_name": ["OBSERVATION"],
"name": ["RESPONSE"], # Has to correspond to actual response name
"index": [0, 1, 2],
"report_step": [0],
},
Expand Down Expand Up @@ -517,18 +519,17 @@ def g(X):
obs = xr.Dataset(
{
"observations": (
["report_step", "index"],
observations.reshape((1, num_observations)),
["index", "report_step"],
observations.reshape((num_observations, 1)),
),
"std": (
["report_step", "index"],
observation_noise.reshape(1, num_observations),
["index", "report_step"],
observation_noise.reshape(num_observations, 1),
),
},
coords={"report_step": [0], "index": np.arange(len(observations))},
attrs={"response": "gen_data"},
)

obs = obs.expand_dims({"obs_name": ["OBSERVATION"]})
obs = obs.expand_dims({"name": ["RESPONSE"]})

Expand Down

0 comments on commit c313ca5

Please sign in to comment.