Move performance tests out of unit_tests
eivindjahren committed Dec 2, 2024
1 parent 6197ed4 commit 08e98c5
Showing 3 changed files with 207 additions and 169 deletions.
205 changes: 204 additions & 1 deletion tests/ert/performance_tests/test_dark_storage_performance.py
@@ -1,19 +1,48 @@
import contextlib
import gc
import io
import os
import time
from asyncio import get_event_loop
from datetime import datetime, timedelta
from typing import Awaitable, TypeVar
from urllib.parse import quote

import memray
import numpy as np
import pandas as pd
import polars
import pytest
from httpx import RequestError
from starlette.testclient import TestClient

-from ert.config import ErtConfig
+from ert.config import ErtConfig, SummaryConfig
from ert.dark_storage import enkf
from ert.dark_storage.app import app
from ert.dark_storage.endpoints import ensembles, experiments, records
from ert.gui.tools.plot.plot_api import PlotApi
from ert.libres_facade import LibresFacade
from ert.services import StorageService
from ert.storage import open_storage

T = TypeVar("T")


@pytest.fixture(autouse=True)
def use_testclient(monkeypatch):
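    # Route all PlotApi traffic through an in-process Starlette TestClient;
    # no real HTTP server is started for these tests.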
client = TestClient(app)
monkeypatch.setattr(StorageService, "session", lambda: client)

def test_escape(s: str) -> str:
"""
Workaround for issue with TestClient:
https://github.com/encode/starlette/issues/1060
"""
return quote(quote(quote(s, safe="")))

PlotApi.escape = test_escape


def run_in_loop(coro: Awaitable[T]) -> T:
return get_event_loop().run_until_complete(coro)

@@ -178,3 +207,177 @@ def test_direct_dark_performance_with_storage(
ensemble_id_default = ensemble_id

benchmark(function, storage, ensemble_id_default, key, template_config)


@pytest.fixture
def api_and_storage(monkeypatch, tmp_path):
with open_storage(tmp_path / "storage", mode="w") as storage:
monkeypatch.setenv("ERT_STORAGE_NO_TOKEN", "yup")
monkeypatch.setenv("ERT_STORAGE_ENS_PATH", storage.path)
api = PlotApi()
yield api, storage
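        # Teardown: dark_storage caches an open storage handle at module
        # level (enkf._storage); close and reset it so later tests start clean.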
if enkf._storage is not None:
enkf._storage.close()
enkf._storage = None
gc.collect()


@pytest.fixture
def api_and_snake_oil_storage(snake_oil_case_storage, monkeypatch):
with open_storage(snake_oil_case_storage.ens_path, mode="r") as storage:
monkeypatch.setenv("ERT_STORAGE_NO_TOKEN", "yup")
monkeypatch.setenv("ERT_STORAGE_ENS_PATH", storage.path)

api = PlotApi()
yield api, storage

if enkf._storage is not None:
enkf._storage.close()
enkf._storage = None
gc.collect()


@pytest.mark.parametrize(
"num_reals, num_dates, num_keys, max_memory_mb",
    [  # Tested 24.11.22 on MacBook Pro M1 Max
        # (xr = tested on previous ert using xarray to store responses)
        (1, 100, 100, 1200),  # 790 MiB local, xr: 791 MiB
        (1000, 100, 100, 1500),  # 809 MiB local, 879 MiB linux-3.11, xr: 1107 MiB
        # The cases below are more realistic at up to 200 realizations,
        # but should not be run on GHA runners:
        # (2000, 100, 100, 1950),  # 1607 MiB local, 1716 MiB linux-3.12, 1863 MiB linux-3.11, xr: 2186 MiB
        # (2, 5803, 11787, 5500),  # 4657 MiB local, xr: 10115 MiB
        # (10, 5803, 11787, 13500),  # 10036 MiB local, 12803 MiB mac-3.12, xr: 46715 MiB
],
)
def test_plot_api_big_summary_memory_usage(
num_reals, num_dates, num_keys, max_memory_mb, use_tmpdir, api_and_storage
):
api, storage = api_and_storage

dates = []

for i in range(num_keys):
dates += [datetime(2000, 1, 1) + timedelta(days=i)] * num_dates

dates_df = polars.Series(dates, dtype=polars.Datetime).dt.cast_time_unit("ms")

keys_df = polars.Series([f"K{i}" for i in range(num_keys)])
values_df = polars.Series(list(range(num_keys * num_dates)), dtype=polars.Float32)

big_summary = polars.DataFrame(
{
"response_key": polars.concat([keys_df] * num_dates),
"time": dates_df,
"values": values_df,
}
)
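    # big_summary holds num_keys * num_dates rows; the same frame is saved
    # once per realization below, so stored data scales with all three sizes.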

experiment = storage.create_experiment(
parameters=[],
responses=[
SummaryConfig(
name="summary",
input_files=["CASE.UNSMRY", "CASE.SMSPEC"],
keys=keys_df,
)
],
)

ensemble = experiment.create_ensemble(ensemble_size=num_reals, name="bigboi")
for real in range(ensemble.ensemble_size):
ensemble.save_response("summary", big_summary.clone(), real)

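    # native_traces=True records native (C/C++) stack frames for allocations;
    # follow_fork=True keeps tracking inside any forked child processes.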
with memray.Tracker("memray.bin", follow_fork=True, native_traces=True):
# Initialize plotter window
all_keys = {k.key for k in api.all_data_type_keys()}
all_ensembles = [e.id for e in api.get_all_ensembles()]
assert set(keys_df.to_list()) == set(all_keys)

# call updatePlot()
ensemble_to_data_map: dict[str, pd.DataFrame] = {}
sample_key = keys_df.sample(1).item()
for ensemble in all_ensembles:
ensemble_to_data_map[ensemble] = api.data_for_key(ensemble, sample_key)

for ensemble in all_ensembles:
data = ensemble_to_data_map[ensemble]

# Transpose it twice as done in plotter
# (should ideally be avoided)
_ = data.T
_ = data.T

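    # Read the allocation totals back out of the capture file with memray's
    # (private) compute_statistics helper, then remove the dump.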
stats = memray._memray.compute_statistics("memray.bin")
os.remove("memray.bin")
total_memory_usage = stats.total_memory_allocated / (1024**2)
assert total_memory_usage < max_memory_mb


def test_plotter_on_all_snake_oil_responses_time(api_and_snake_oil_storage):
api, _ = api_and_snake_oil_storage
t0 = time.time()
key_infos = api.all_data_type_keys()
all_ensembles = api.get_all_ensembles()
t1 = time.time()
# Cycle through all ensembles and get all responses
for key_info in key_infos:
for ensemble in all_ensembles:
api.data_for_key(ensemble_id=ensemble.id, key=key_info.key)

if key_info.observations:
with contextlib.suppress(RequestError, TimeoutError):
api.observations_for_key(
[ens.id for ens in all_ensembles], key_info.key
)

# Note: Does not test for fields
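        # (Summary keys ending in "H" are history vectors, which have no
        # separate history data to fetch.)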
if not (str(key_info.key).endswith("H") or "H:" in str(key_info.key)):
with contextlib.suppress(RequestError, TimeoutError):
api.history_data(
key_info.key,
[e.id for e in all_ensembles],
)

t2 = time.time()
time_to_get_metadata = t1 - t0
time_to_cycle_through_responses = t2 - t1

# Local times were about 10% of the asserted times
assert time_to_get_metadata < 1
assert time_to_cycle_through_responses < 14


def test_plotter_on_all_snake_oil_responses_memory(api_and_snake_oil_storage):
api, _ = api_and_snake_oil_storage

with memray.Tracker("memray.bin", follow_fork=True, native_traces=True):
key_infos = api.all_data_type_keys()
all_ensembles = api.get_all_ensembles()
# Cycle through all ensembles and get all responses
for key_info in key_infos:
for ensemble in all_ensembles:
api.data_for_key(ensemble_id=ensemble.id, key=key_info.key)

if key_info.observations:
with contextlib.suppress(RequestError, TimeoutError):
api.observations_for_key(
[ens.id for ens in all_ensembles], key_info.key
)

# Note: Does not test for fields
if not (str(key_info.key).endswith("H") or "H:" in str(key_info.key)):
with contextlib.suppress(RequestError, TimeoutError):
api.history_data(
key_info.key,
[e.id for e in all_ensembles],
)

stats = memray._memray.compute_statistics("memray.bin")
os.remove("memray.bin")
total_memory_mb = stats.total_memory_allocated / (1024**2)
peak_memory_mb = stats.peak_memory_allocated / (1024**2)

# thresholds are set to about 1.5x local memory used
assert total_memory_mb < 5000
assert peak_memory_mb < 1500