Skip to content

Commit

Permalink
Use strategy="nearest" for asof join
Browse files Browse the repository at this point in the history
  • Loading branch information
yngve-sk committed Feb 4, 2025
1 parent c1ca814 commit fcdb7ca
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/ert/storage/local_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -989,6 +989,7 @@ def get_observations_and_responses(
*[k for k in response_cls.primary_key if k != "time"],
],
on="time",
strategy="nearest",
tolerance="1s",
)
else:
Expand Down
97 changes: 97 additions & 0 deletions tests/ert/unit_tests/storage/test_local_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,6 +637,103 @@ def test_write_transaction_overwrites(tmp_path):
assert path.read_bytes() == b"deadbeaf"


@pytest.mark.parametrize(
"perturb_observations, perturb_responses",
[
pytest.param(
False,
True,
id="Perturbed responses",
),
pytest.param(
True,
False,
id="Perturbed observations",
),
pytest.param(
True,
False,
id="Perturbed observations & responses",
),
],
)
def test_asof_joining_summary(tmp_path, perturb_observations, perturb_responses):
with open_storage(tmp_path, mode="w") as storage:
response_keys = ["FOPR", "FOPT_OP1", "FOPR:OP3", "FLAP", "F*"]
obs_keys = [f"o_{k}" for k in response_keys]
times = [datetime(2000, 1, 1, 1, 0)] * len(response_keys)
summary_observations = polars.DataFrame(
{
"observation_key": obs_keys,
"response_key": response_keys,
"time": polars.Series(
times,
dtype=polars.Datetime("ms"),
),
"observations": polars.Series(
[1] * len(response_keys),
dtype=polars.Float32,
),
"std": polars.Series(
[0.1] * len(response_keys),
dtype=polars.Float32,
),
}
)

experiment = storage.create_experiment(
responses=[SummaryConfig(keys=["*"], input_files=["not_relevant"])],
observations={"summary": summary_observations},
)

ensemble = storage.create_ensemble(
experiment, ensemble_size=1, iteration=0, name="prior"
)

summary_df = polars.DataFrame(
{
"response_key": response_keys,
"time": polars.Series(times, dtype=polars.Datetime("ms")),
"values": polars.Series(
[0.0, 1.0, 2.0, 3.0, 4.0], dtype=polars.Float32
),
}
)

ensemble.save_response("summary", summary_df, 0)
iens_active_index = np.array([0])

obs_and_responses_exact = ensemble.get_observations_and_responses(
obs_keys, iens_active_index
)

if perturb_responses:
perturbed_summary = summary_df.with_columns(
polars.when(polars.arange(0, summary_df.height) % 2 != 0)
.then(polars.col("time") + polars.duration(milliseconds=500))
.otherwise(polars.col("time") - polars.duration(milliseconds=500))
.alias("time")
)
ensemble.save_response("summary", perturbed_summary, 0)

if perturb_observations:
perturbed_observations = summary_observations.with_columns(
polars.when(polars.arange(0, summary_observations.height) % 2 != 0)
.then(polars.col("time") + polars.duration(milliseconds=500))
.otherwise(polars.col("time") - polars.duration(milliseconds=500))
.alias("time")
)
experiment.observations["summary"] = perturbed_observations

obs_and_responses_perturbed = ensemble.get_observations_and_responses(
obs_keys, iens_active_index
)

assert obs_and_responses_exact.drop("index").equals(
obs_and_responses_perturbed.drop("index")
)


@dataclass
class Ensemble:
uuid: UUID
Expand Down

0 comments on commit fcdb7ca

Please sign in to comment.