diff --git a/tests/ert/performance_tests/test_dark_storage_performance.py b/tests/ert/performance_tests/test_dark_storage_performance.py
index 4f5bb2e1026..758530d9cbb 100644
--- a/tests/ert/performance_tests/test_dark_storage_performance.py
+++ b/tests/ert/performance_tests/test_dark_storage_performance.py
@@ -1,19 +1,48 @@
+import contextlib
+import gc
 import io
+import os
+import time
 from asyncio import get_event_loop
+from datetime import datetime, timedelta
 from typing import Awaitable, TypeVar
+from urllib.parse import quote
 
+import memray
 import numpy as np
 import pandas as pd
+import polars
 import pytest
+from httpx import RequestError
+from starlette.testclient import TestClient
 
-from ert.config import ErtConfig
+from ert.config import ErtConfig, SummaryConfig
+from ert.dark_storage import enkf
+from ert.dark_storage.app import app
 from ert.dark_storage.endpoints import ensembles, experiments, records
+from ert.gui.tools.plot.plot_api import PlotApi
 from ert.libres_facade import LibresFacade
+from ert.services import StorageService
 from ert.storage import open_storage
 
 T = TypeVar("T")
 
 
+@pytest.fixture(autouse=True)
+def use_testclient(monkeypatch):
+    client = TestClient(app)
+    monkeypatch.setattr(StorageService, "session", lambda: client)
+
+    def test_escape(s: str) -> str:
+        """
+        Workaround for issue with TestClient:
+        https://github.com/encode/starlette/issues/1060
+        """
+        return quote(quote(quote(s, safe="")))
+
+    PlotApi.escape = test_escape
+
+
 def run_in_loop(coro: Awaitable[T]) -> T:
     return get_event_loop().run_until_complete(coro)
 
@@ -178,3 +207,177 @@ def test_direct_dark_performance_with_storage(
     ensemble_id_default = ensemble_id
 
     benchmark(function, storage, ensemble_id_default, key, template_config)
+
+
+@pytest.fixture
+def api_and_storage(monkeypatch, tmp_path):
+    with open_storage(tmp_path / "storage", mode="w") as storage:
+        monkeypatch.setenv("ERT_STORAGE_NO_TOKEN", "yup")
+        monkeypatch.setenv("ERT_STORAGE_ENS_PATH", storage.path)
+        api = PlotApi()
+        yield api, storage
+    if enkf._storage is not None:
+        enkf._storage.close()
+    enkf._storage = None
+    gc.collect()
+
+
+@pytest.fixture
+def api_and_snake_oil_storage(snake_oil_case_storage, monkeypatch):
+    with open_storage(snake_oil_case_storage.ens_path, mode="r") as storage:
+        monkeypatch.setenv("ERT_STORAGE_NO_TOKEN", "yup")
+        monkeypatch.setenv("ERT_STORAGE_ENS_PATH", storage.path)
+
+        api = PlotApi()
+        yield api, storage
+
+    if enkf._storage is not None:
+        enkf._storage.close()
+    enkf._storage = None
+    gc.collect()
+
+
+@pytest.mark.parametrize(
+    "num_reals, num_dates, num_keys, max_memory_mb",
+    [  # Tested 24.11.22 on MacBook Pro M1 Max
+        # (xr = tested on previous ert using xarray to store responses)
+        (1, 100, 100, 1200),  # 790MiB local, xr: 791MiB
+        (1000, 100, 100, 1500),  # 809MiB local, 879MiB linux-3.11, xr: 1107MiB
+        # (Cases below are more realistic at up to 200 realizations)
+        # Not to be run on GHA runners
+        # (2000, 100, 100, 1950),  # 1607MiB local, 1716MiB linux-3.12, 1863MiB linux-3.11, xr: 2186MiB
+        # (2, 5803, 11787, 5500),  # 4657MiB local, xr: 10115MiB
+        # (10, 5803, 11787, 13500),  # 10036MiB local, 12803MiB mac-3.12, xr: 46715MiB
+    ],
+)
+def test_plot_api_big_summary_memory_usage(
+    num_reals, num_dates, num_keys, max_memory_mb, use_tmpdir, api_and_storage
+):
+    api, storage = api_and_storage
+
+    dates = []
+
+    for i in range(num_keys):
+        dates += [datetime(2000, 1, 1) + timedelta(days=i)] * num_dates
+
+    dates_df = polars.Series(dates, dtype=polars.Datetime).dt.cast_time_unit("ms")
+
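+    # The frame below is in long format: every response key is paired with
+    # every date, giving num_keys * num_dates rows in total.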
+    keys_df = polars.Series([f"K{i}" for i in range(num_keys)])
+    values_df = polars.Series(list(range(num_keys * num_dates)), dtype=polars.Float32)
+
+    big_summary = polars.DataFrame(
+        {
+            "response_key": polars.concat([keys_df] * num_dates),
+            "time": dates_df,
+            "values": values_df,
+        }
+    )
+
+    experiment = storage.create_experiment(
+        parameters=[],
+        responses=[
+            SummaryConfig(
+                name="summary",
+                input_files=["CASE.UNSMRY", "CASE.SMSPEC"],
+                keys=keys_df,
+            )
+        ],
+    )
+
+    ensemble = experiment.create_ensemble(ensemble_size=num_reals, name="bigboi")
+    for real in range(ensemble.ensemble_size):
+        ensemble.save_response("summary", big_summary.clone(), real)
+
+    with memray.Tracker("memray.bin", follow_fork=True, native_traces=True):
+        # Initialize plotter window
+        all_keys = {k.key for k in api.all_data_type_keys()}
+        all_ensembles = [e.id for e in api.get_all_ensembles()]
+        assert set(keys_df.to_list()) == set(all_keys)
+
+        # call updatePlot()
+        ensemble_to_data_map: dict[str, pd.DataFrame] = {}
+        sample_key = keys_df.sample(1).item()
+        for ensemble in all_ensembles:
+            ensemble_to_data_map[ensemble] = api.data_for_key(ensemble, sample_key)
+
+        for ensemble in all_ensembles:
+            data = ensemble_to_data_map[ensemble]
+
+            # Transpose it twice as done in plotter
+            # (should ideally be avoided)
+            _ = data.T
+            _ = data.T
+
+    stats = memray._memray.compute_statistics("memray.bin")
+    os.remove("memray.bin")
+    total_memory_usage = stats.total_memory_allocated / (1024**2)
+    assert total_memory_usage < max_memory_mb
+
+
+def test_plotter_on_all_snake_oil_responses_time(api_and_snake_oil_storage):
+    api, _ = api_and_snake_oil_storage
+    t0 = time.time()
+    key_infos = api.all_data_type_keys()
+    all_ensembles = api.get_all_ensembles()
+    t1 = time.time()
+    # Cycle through all ensembles and get all responses
+    for key_info in key_infos:
+        for ensemble in all_ensembles:
+            api.data_for_key(ensemble_id=ensemble.id, key=key_info.key)
+
+        if key_info.observations:
+            with contextlib.suppress(RequestError, TimeoutError):
+                api.observations_for_key(
+                    [ens.id for ens in all_ensembles], key_info.key
+                )
+
+        # Note: Does not test for fields
+        # Skip history keys (ending in "H" or containing "H:")
+        if not (str(key_info.key).endswith("H") or "H:" in str(key_info.key)):
+            with contextlib.suppress(RequestError, TimeoutError):
+                api.history_data(
+                    key_info.key,
+                    [e.id for e in all_ensembles],
+                )
+
+    t2 = time.time()
+    time_to_get_metadata = t1 - t0
+    time_to_cycle_through_responses = t2 - t1
+
+    # Local times were about 10% of the asserted times
+    assert time_to_get_metadata < 1
+    assert time_to_cycle_through_responses < 14
+
+
+def test_plotter_on_all_snake_oil_responses_memory(api_and_snake_oil_storage):
+    api, _ = api_and_snake_oil_storage
+
+    with memray.Tracker("memray.bin", follow_fork=True, native_traces=True):
+        key_infos = api.all_data_type_keys()
+        all_ensembles = api.get_all_ensembles()
+        # Cycle through all ensembles and get all responses
+        for key_info in key_infos:
+            for ensemble in all_ensembles:
+                api.data_for_key(ensemble_id=ensemble.id, key=key_info.key)
+
+            if key_info.observations:
+                with contextlib.suppress(RequestError, TimeoutError):
+                    api.observations_for_key(
+                        [ens.id for ens in all_ensembles], key_info.key
+                    )
+
+            # Note: Does not test for fields
+            if not (str(key_info.key).endswith("H") or "H:" in str(key_info.key)):
+                with contextlib.suppress(RequestError, TimeoutError):
+                    api.history_data(
+                        key_info.key,
+                        [e.id for e in all_ensembles],
+                    )
+
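+    # compute_statistics reads back the dump the Tracker wrote above;
+    # allocation counters are in bytes, hence the MiB conversion below.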
+    stats = memray._memray.compute_statistics("memray.bin")
+    os.remove("memray.bin")
+    total_memory_mb = stats.total_memory_allocated / (1024**2)
+    peak_memory_mb = stats.peak_memory_allocated / (1024**2)
+
+    # Thresholds are set to about 1.5x local memory used
+    assert total_memory_mb < 5000
+    assert peak_memory_mb < 1500
diff --git a/tests/ert/unit_tests/gui/tools/plot/test_plot_api.py b/tests/ert/unit_tests/gui/tools/plot/test_plot_api.py
index a854179bfca..49ebfddd3a4 100644
--- a/tests/ert/unit_tests/gui/tools/plot/test_plot_api.py
+++ b/tests/ert/unit_tests/gui/tools/plot/test_plot_api.py
@@ -1,19 +1,13 @@
-import contextlib
 import gc
-import os
-import time
-from datetime import datetime, timedelta
+from datetime import datetime
 from textwrap import dedent
-from typing import Dict
 from urllib.parse import quote
 
 import httpx
-import memray
 import pandas as pd
 import polars
 import pytest
 import xarray as xr
-from httpx import RequestError
 from pandas.testing import assert_frame_equal
 from starlette.testclient import TestClient
 
@@ -201,167 +195,6 @@ def api_and_storage(monkeypatch, tmp_path):
     gc.collect()
 
 
-@pytest.fixture
-def api_and_snake_oil_storage(snake_oil_case_storage, monkeypatch):
-    with open_storage(snake_oil_case_storage.ens_path, mode="r") as storage:
-        monkeypatch.setenv("ERT_STORAGE_NO_TOKEN", "yup")
-        monkeypatch.setenv("ERT_STORAGE_ENS_PATH", storage.path)
-
-        api = PlotApi()
-        yield api, storage
-
-    if enkf._storage is not None:
-        enkf._storage.close()
-    enkf._storage = None
-    gc.collect()
-
-
-@pytest.mark.parametrize(
-    "num_reals, num_dates, num_keys, max_memory_mb",
-    [  # Tested 24.11.22 on macbook pro M1 max
-        # (xr = tested on previous ert using xarray to store responses)
-        (1, 100, 100, 1200),  # 790MiB local, xr: 791, MiB
-        (1000, 100, 100, 1500),  # 809MiB local, 879MiB linux-3.11, xr: 1107MiB
-        # (Cases below are more realistic at up to 200realizations)
-        # Not to be run these on GHA runners
-        # (2000, 100, 100, 1950),  # 1607MiB local, 1716MiB linux3.12, 1863 on linux3.11, xr: 2186MiB
-        # (2, 5803, 11787, 5500),  # 4657MiB local, xr: 10115MiB
-        # (10, 5803, 11787, 13500),  # 10036MiB local, 12803MiB mac-3.12, xr: 46715MiB
-    ],
-)
-def test_plot_api_big_summary_memory_usage(
-    num_reals, num_dates, num_keys, max_memory_mb, use_tmpdir, api_and_storage
-):
-    api, storage = api_and_storage
-
-    dates = []
-
-    for i in range(num_keys):
-        dates += [datetime(2000, 1, 1) + timedelta(days=i)] * num_dates
-
-    dates_df = polars.Series(dates, dtype=polars.Datetime).dt.cast_time_unit("ms")
-
-    keys_df = polars.Series([f"K{i}" for i in range(num_keys)])
-    values_df = polars.Series(list(range(num_keys * num_dates)), dtype=polars.Float32)
-
-    big_summary = polars.DataFrame(
-        {
-            "response_key": polars.concat([keys_df] * num_dates),
-            "time": dates_df,
-            "values": values_df,
-        }
-    )
-
-    experiment = storage.create_experiment(
-        parameters=[],
-        responses=[
-            SummaryConfig(
-                name="summary",
-                input_files=["CASE.UNSMRY", "CASE.SMSPEC"],
-                keys=keys_df,
-            )
-        ],
-    )
-
-    ensemble = experiment.create_ensemble(ensemble_size=num_reals, name="bigboi")
-    for real in range(ensemble.ensemble_size):
-        ensemble.save_response("summary", big_summary.clone(), real)
-
-    with memray.Tracker("memray.bin", follow_fork=True, native_traces=True):
-        # Initialize plotter window
-        all_keys = {k.key for k in api.all_data_type_keys()}
-        all_ensembles = [e.id for e in api.get_all_ensembles()]
-        assert set(keys_df.to_list()) == set(all_keys)
-
-        # call updatePlot()
-        ensemble_to_data_map: Dict[str, pd.DataFrame] = {}
-        sample_key = keys_df.sample(1).item()
-        for ensemble in all_ensembles:
-            ensemble_to_data_map[ensemble] = api.data_for_key(ensemble, sample_key)
-
-        for ensemble in all_ensembles:
-            data = ensemble_to_data_map[ensemble]
-
-            # Transpose it twice as done in plotter
-            # (should ideally be avoided)
-            _ = data.T
-            _ = data.T
-
-    stats = memray._memray.compute_statistics("memray.bin")
-    os.remove("memray.bin")
-    total_memory_usage = stats.total_memory_allocated / (1024**2)
-    assert total_memory_usage < max_memory_mb
-
-
-def test_plotter_on_all_snake_oil_responses_time(api_and_snake_oil_storage):
-    api, _ = api_and_snake_oil_storage
-    t0 = time.time()
-    key_infos = api.all_data_type_keys()
-    all_ensembles = api.get_all_ensembles()
-    t1 = time.time()
-    # Cycle through all ensembles and get all responses
-    for key_info in key_infos:
-        for ensemble in all_ensembles:
-            api.data_for_key(ensemble_id=ensemble.id, key=key_info.key)
-
-        if key_info.observations:
-            with contextlib.suppress(RequestError, TimeoutError):
-                api.observations_for_key(
-                    [ens.id for ens in all_ensembles], key_info.key
-                )
-
-        # Note: Does not test for fields
-        if not (str(key_info.key).endswith("H") or "H:" in str(key_info.key)):
-            with contextlib.suppress(RequestError, TimeoutError):
-                api.history_data(
-                    key_info.key,
-                    [e.id for e in all_ensembles],
-                )
-
-    t2 = time.time()
-    time_to_get_metadata = t1 - t0
-    time_to_cycle_through_responses = t2 - t1
-
-    # Local times were about 10% of the asserted times
-    assert time_to_get_metadata < 1
-    assert time_to_cycle_through_responses < 14
-
-
-def test_plotter_on_all_snake_oil_responses_memory(api_and_snake_oil_storage):
-    api, _ = api_and_snake_oil_storage
-
-    with memray.Tracker("memray.bin", follow_fork=True, native_traces=True):
-        key_infos = api.all_data_type_keys()
-        all_ensembles = api.get_all_ensembles()
-        # Cycle through all ensembles and get all responses
-        for key_info in key_infos:
-            for ensemble in all_ensembles:
-                api.data_for_key(ensemble_id=ensemble.id, key=key_info.key)
-
-            if key_info.observations:
-                with contextlib.suppress(RequestError, TimeoutError):
-                    api.observations_for_key(
-                        [ens.id for ens in all_ensembles], key_info.key
-                    )
-
-            # Note: Does not test for fields
-            if not (str(key_info.key).endswith("H") or "H:" in str(key_info.key)):
-                with contextlib.suppress(RequestError, TimeoutError):
-                    api.history_data(
-                        key_info.key,
-                        [e.id for e in all_ensembles],
-                    )
-
-    stats = memray._memray.compute_statistics("memray.bin")
-    os.remove("memray.bin")
-    total_memory_mb = stats.total_memory_allocated / (1024**2)
-    peak_memory_mb = stats.peak_memory_allocated / (1024**2)
-
-    # thresholds are set to about 1.5x local memory used
-    assert total_memory_mb < 5000
-    assert peak_memory_mb < 1500
-
-
 def test_plot_api_handles_urlescape(api_and_storage):
     api, storage = api_and_storage
     key = "WBHP:46/3-7S"
diff --git a/tests/ert/unit_tests/test_libres_facade.py b/tests/ert/unit_tests/test_libres_facade.py
index bcc3efa098a..98ecd91a359 100644
--- a/tests/ert/unit_tests/test_libres_facade.py
+++ b/tests/ert/unit_tests/test_libres_facade.py
@@ -31,6 +31,7 @@ def empty_case(facade, storage):
     )
 
 
+@pytest.mark.integration_test
 def test_keyword_type_checks(snake_oil_default_storage):
     assert (
         "BPR:1,3,8"
@@ -40,6 +41,7 @@ def test_keyword_type_checks(snake_oil_default_storage):
     )
 
 
+@pytest.mark.integration_test
 def test_keyword_type_checks_missing_key(snake_oil_default_storage):
     assert (
         "nokey"