diff --git a/tests/performance_tests/test_memory_usage.py b/tests/performance_tests/test_memory_usage.py index 6385494bd8d..32c8a0db8d1 100644 --- a/tests/performance_tests/test_memory_usage.py +++ b/tests/performance_tests/test_memory_usage.py @@ -3,9 +3,9 @@ from typing import List import numpy as np +import polars import py import pytest -import xarray as xr from ert.analysis import smoother_update from ert.config import ErtConfig @@ -39,7 +39,7 @@ def poly_template(monkeypatch): @pytest.mark.flaky(reruns=5) -@pytest.mark.limit_memory("130 MB") +@pytest.mark.limit_memory("450 MB") @pytest.mark.integration_test def test_memory_smoothing(poly_template): ert_config = ErtConfig.from_file("poly.ert") @@ -57,7 +57,7 @@ def test_memory_smoothing(poly_template): smoother_update( prior_ens, posterior_ens, - list(ert_config.observations.keys()), + list(experiment.observation_keys), list(ert_config.ensemble_config.parameters), ) @@ -77,26 +77,28 @@ def fill_storage_with_data(poly_template: Path, ert_config: ErtConfig) -> None: realizations = list(range(ert_config.model_config.num_realizations)) for real in realizations: gendatas = [] - for _, obs in ert_config.observations.items(): - data_key = obs.attrs["response"] - if data_key != "summary": - obs_highest_index_used = max(obs.index.values) - gendatas.append( - make_gen_data(int(obs_highest_index_used) + 1).expand_dims( - name=[data_key] - ) - ) - else: - obs_time_list = ens_config.refcase.all_dates - source.save_response( - data_key, - make_summary_data(["summary"], obs_time_list), - real, - ) + gen_obs = ert_config.observations["gen_data"] + for response_key, df in gen_obs.group_by("response_key"): + gendata_df = make_gen_data(df["index"].max() + 1) + gendata_df = gendata_df.insert_column( + 0, + polars.Series(np.full(len(gendata_df), response_key)).alias( + "response_key" + ), + ) + gendatas.append(gendata_df) + + source.save_response("gen_data", polars.concat(gendatas), real) + + obs_time_list = ens_config.refcase.all_dates + + summary_keys = ert_config.observations["summary"]["response_key"].unique( + maintain_order=True + ) source.save_response( - "gen_data", - xr.concat(gendatas, dim="name"), + "summary", + make_summary_data(summary_keys, obs_time_list), real, ) @@ -111,11 +113,14 @@ def fill_storage_with_data(poly_template: Path, ert_config: ErtConfig) -> None: ) -def make_gen_data(obs: int, min_val: float = 0, max_val: float = 5) -> xr.Dataset: +def make_gen_data(obs: int, min_val: float = 0, max_val: float = 5) -> polars.DataFrame: data = np.random.default_rng().uniform(min_val, max_val, obs) - return xr.Dataset( - {"values": (["report_step", "index"], [data])}, - coords={"index": range(len(data)), "report_step": [0]}, + return polars.DataFrame( + { + "report_step": polars.Series(np.full(len(data), 0), dtype=polars.UInt16), + "index": polars.Series(range(len(data)), dtype=polars.UInt16), + "values": data, + } ) @@ -124,11 +129,15 @@ def make_summary_data( dates, min_val: float = 0, max_val: float = 5, -) -> xr.Dataset: - data = np.random.default_rng().uniform( - min_val, max_val, (len(obs_keys), len(dates)) - ) - return xr.Dataset( - {"values": (["name", "time"], data)}, - coords={"time": dates, "name": obs_keys}, +) -> polars.DataFrame: + data = np.random.default_rng().uniform(min_val, max_val, len(obs_keys) * len(dates)) + + return polars.DataFrame( + { + "response_key": np.repeat(obs_keys, len(dates)), + "time": polars.Series( + np.tile(dates, len(obs_keys)).tolist() + ).dt.cast_time_unit("ms"), + "values": data, + } )