Skip to content

Commit

Permalink
Make test_memory_usage work
Browse files Browse the repository at this point in the history
  • Loading branch information
Yngve S. Kristiansen committed Sep 16, 2024
1 parent af82675 commit 7d4d4f3
Showing 1 changed file with 41 additions and 32 deletions.
73 changes: 41 additions & 32 deletions tests/performance_tests/test_memory_usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from typing import List

import numpy as np
import polars
import py
import pytest
import xarray as xr

from ert.analysis import smoother_update
from ert.config import ErtConfig
Expand Down Expand Up @@ -39,7 +39,7 @@ def poly_template(monkeypatch):


@pytest.mark.flaky(reruns=5)
@pytest.mark.limit_memory("130 MB")
@pytest.mark.limit_memory("450 MB")
@pytest.mark.integration_test
def test_memory_smoothing(poly_template):
ert_config = ErtConfig.from_file("poly.ert")
Expand All @@ -57,7 +57,7 @@ def test_memory_smoothing(poly_template):
smoother_update(
prior_ens,
posterior_ens,
list(ert_config.observations.keys()),
list(experiment.observation_keys),
list(ert_config.ensemble_config.parameters),
)

Expand All @@ -77,26 +77,28 @@ def fill_storage_with_data(poly_template: Path, ert_config: ErtConfig) -> None:
realizations = list(range(ert_config.model_config.num_realizations))
for real in realizations:
gendatas = []
for _, obs in ert_config.observations.items():
data_key = obs.attrs["response"]
if data_key != "summary":
obs_highest_index_used = max(obs.index.values)
gendatas.append(
make_gen_data(int(obs_highest_index_used) + 1).expand_dims(
name=[data_key]
)
)
else:
obs_time_list = ens_config.refcase.all_dates
source.save_response(
data_key,
make_summary_data(["summary"], obs_time_list),
real,
)
gen_obs = ert_config.observations["gen_data"]
for response_key, df in gen_obs.group_by("response_key"):
gendata_df = make_gen_data(df["index"].max() + 1)
gendata_df = gendata_df.insert_column(
0,
polars.Series(np.full(len(gendata_df), response_key)).alias(
"response_key"
),
)
gendatas.append(gendata_df)

source.save_response("gen_data", polars.concat(gendatas), real)

obs_time_list = ens_config.refcase.all_dates

summary_keys = ert_config.observations["summary"]["response_key"].unique(
maintain_order=True
)

source.save_response(
"gen_data",
xr.concat(gendatas, dim="name"),
"summary",
make_summary_data(summary_keys, obs_time_list),
real,
)

Expand All @@ -111,11 +113,14 @@ def fill_storage_with_data(poly_template: Path, ert_config: ErtConfig) -> None:
)


def make_gen_data(obs: int, min_val: float = 0, max_val: float = 5) -> polars.DataFrame:
    """Build a synthetic gen_data response frame with *obs* rows.

    Each row gets a uniformly random value in [min_val, max_val); all rows
    belong to report_step 0 and are indexed 0..obs-1.

    :param obs: number of observation rows to generate
    :param min_val: lower bound of the uniform draw
    :param max_val: upper bound of the uniform draw
    :returns: DataFrame with columns ``report_step`` (UInt16), ``index``
        (UInt16) and ``values`` (float)
    """
    rng = np.random.default_rng()
    values = rng.uniform(min_val, max_val, obs)
    n_rows = len(values)
    # report_step/index are stored as UInt16 to match the storage schema
    return polars.DataFrame(
        {
            "report_step": polars.Series(np.full(n_rows, 0), dtype=polars.UInt16),
            "index": polars.Series(range(n_rows), dtype=polars.UInt16),
            "values": values,
        }
    )


def make_summary_data(
    obs_keys: List[str],  # NOTE(review): first param collapsed in diff view — name inferred from body; confirm
    dates,
    min_val: float = 0,
    max_val: float = 5,
) -> polars.DataFrame:
    """Build a synthetic summary response frame.

    Produces one row per (response key, date) pair with a uniformly random
    value in [min_val, max_val).

    :param obs_keys: summary response keys, repeated once per date
    :param dates: timestamps for each key (one row per key/date pair)
    :param min_val: lower bound of the uniform draw
    :param max_val: upper bound of the uniform draw
    :returns: DataFrame with columns ``response_key``, ``time`` (datetime,
        millisecond time unit) and ``values``
    """
    rng = np.random.default_rng()
    n_rows = len(obs_keys) * len(dates)
    values = rng.uniform(min_val, max_val, n_rows)
    # Storage expects millisecond-resolution timestamps, so cast explicitly.
    times = polars.Series(np.tile(dates, len(obs_keys)).tolist()).dt.cast_time_unit(
        "ms"
    )
    return polars.DataFrame(
        {
            "response_key": np.repeat(obs_keys, len(dates)),
            "time": times,
            "values": values,
        }
    )

0 comments on commit 7d4d4f3

Please sign in to comment.