From 6d9750dc0e3282c30dae1e55063379b47947e119 Mon Sep 17 00:00:00 2001 From: Eivind Jahren Date: Fri, 4 Oct 2024 09:11:35 +0200 Subject: [PATCH] Escape slashes in plotapi --- src/ert/dark_storage/common.py | 2 +- src/ert/dark_storage/endpoints/records.py | 4 + src/ert/gui/tools/plot/plot_api.py | 14 ++- src/ert/storage/local_ensemble.py | 2 +- tests/ert/unit_tests/dark_storage/conftest.py | 4 + .../dark_storage/test_dark_storage_state.py | 30 +++--- .../ert/unit_tests/gui/tools/plot/conftest.py | 6 +- .../gui/tools/plot/test_plot_api.py | 91 ++++++++++++++++++- .../unit_tests/storage/test_local_storage.py | 2 +- 9 files changed, 133 insertions(+), 22 deletions(-) diff --git a/src/ert/dark_storage/common.py b/src/ert/dark_storage/common.py index 18402e7d76e..678b0f41a1e 100644 --- a/src/ert/dark_storage/common.py +++ b/src/ert/dark_storage/common.py @@ -109,7 +109,7 @@ def data_for_key( "summary", tuple(ensemble.get_realization_list_with_responses("summary")) ) summary_keys = summary_data["response_key"].unique().to_list() - except (ValueError, KeyError): + except (ValueError, KeyError, polars.exceptions.ColumnNotFoundError): summary_data = polars.DataFrame() summary_keys = [] diff --git a/src/ert/dark_storage/endpoints/records.py b/src/ert/dark_storage/endpoints/records.py index bb172237715..70be40d4070 100644 --- a/src/ert/dark_storage/endpoints/records.py +++ b/src/ert/dark_storage/endpoints/records.py @@ -1,5 +1,6 @@ import io from typing import Any, Dict, List, Mapping, Union +from urllib.parse import unquote from uuid import UUID, uuid4 import numpy as np @@ -34,6 +35,7 @@ async def get_record_observations( ensemble_id: UUID, response_name: str, ) -> List[js.ObservationOut]: + response_name = unquote(response_name) ensemble = storage.get_ensemble(ensemble_id) obs_keys = get_observation_keys_for_response(ensemble, response_name) obss = get_observations_for_obs_keys(ensemble, obs_keys) @@ -74,6 +76,7 @@ async def get_ensemble_record( ensemble_id: UUID, accept: Annotated[Union[str, None], Header()] = None, ) -> Any: + name = unquote(name) dataframe = data_for_key(storage.get_ensemble(ensemble_id), name) media_type = accept if accept is not None else "text/csv" if media_type == "application/x-parquet": @@ -153,6 +156,7 @@ def get_ensemble_responses( def get_std_dev( *, storage: Storage = DEFAULT_STORAGE, ensemble_id: UUID, key: str, z: int ) -> Response: + key = unquote(key) ensemble = storage.get_ensemble(ensemble_id) try: da = ensemble.calculate_std_dev_for_parameter(key)["values"] diff --git a/src/ert/gui/tools/plot/plot_api.py b/src/ert/gui/tools/plot/plot_api.py index faf503b80a8..13a3d374f3c 100644 --- a/src/ert/gui/tools/plot/plot_api.py +++ b/src/ert/gui/tools/plot/plot_api.py @@ -4,6 +4,7 @@ from itertools import combinations as combi from json.decoder import JSONDecodeError from typing import Any, Dict, List, NamedTuple, Optional +from urllib.parse import quote import httpx import numpy as np @@ -42,6 +43,10 @@ def __init__(self) -> None: self._all_ensembles: Optional[List[EnsembleObject]] = None self._timeout = 120 + @staticmethod + def escape(s: str) -> str: + return quote(quote(s, safe="")) + def _get_ensemble_by_id(self, id: str) -> Optional[EnsembleObject]: for ensemble in self.get_all_ensembles(): if ensemble.id == id: @@ -162,8 +167,9 @@ def data_for_key(self, ensemble_id: str, key: str) -> pd.DataFrame: return pd.DataFrame() with StorageService.session() as client: + print(key) response = client.get( - f"/ensembles/{ensemble.id}/records/{key}", + f"/ensembles/{ensemble.id}/records/{PlotApi.escape(key)}", headers={"accept": "application/x-parquet"}, timeout=self._timeout, ) @@ -195,8 +201,9 @@ def observations_for_key(self, ensemble_ids: List[str], key: str) -> pd.DataFram continue with StorageService.session() as client: + print(key) response = client.get( - f"/ensembles/{ensemble.id}/records/{key}/observations", + f"/ensembles/{ensemble.id}/records/{PlotApi.escape(key)}/observations", timeout=self._timeout, ) self._check_response(response) @@ -260,8 +267,9 @@ def std_dev_for_parameter( return np.array([]) with StorageService.session() as client: + print(key) response = client.get( - f"/ensembles/{ensemble.id}/records/{key}/std_dev", + f"/ensembles/{ensemble.id}/records/{PlotApi.escape(key)}/std_dev", params={"z": z}, timeout=self._timeout, ) diff --git a/src/ert/storage/local_ensemble.py b/src/ert/storage/local_ensemble.py index d9f65c6d915..831c26d1f0b 100644 --- a/src/ert/storage/local_ensemble.py +++ b/src/ert/storage/local_ensemble.py @@ -519,7 +519,7 @@ def get_summary_keyset(self) -> List[str]: ) return sorted(summary_data["response_key"].unique().to_list()) - except (ValueError, KeyError): + except (ValueError, KeyError, polars.ColumnNotFoundError): return [] def _load_single_dataset( diff --git a/tests/ert/unit_tests/dark_storage/conftest.py b/tests/ert/unit_tests/dark_storage/conftest.py index 7a040ac0bc5..f3b7bf0ffda 100644 --- a/tests/ert/unit_tests/dark_storage/conftest.py +++ b/tests/ert/unit_tests/dark_storage/conftest.py @@ -12,6 +12,7 @@ from ert.cli.main import run_cli from ert.dark_storage import enkf from ert.dark_storage.app import app +from ert.dark_storage.enkf import update_storage from ert.mode_definitions import ENSEMBLE_SMOOTHER_MODE @@ -52,6 +53,7 @@ def poly_example_tmp_dir(poly_example_tmp_dir_shared): def dark_storage_client(monkeypatch): with dark_storage_app_(monkeypatch) as dark_app: monkeypatch.setenv("ERT_STORAGE_ENS_PATH", "storage") + update_storage() with TestClient(dark_app) as client: yield client @@ -60,6 +62,7 @@ def dark_storage_client(monkeypatch): def dark_storage_client_snake_oil(monkeypatch): with dark_storage_app_(monkeypatch) as dark_app: monkeypatch.setenv("ERT_STORAGE_ENS_PATH", "storage/snake_oil/ensemble") + update_storage() with TestClient(dark_app) as client: yield client @@ -80,6 +83,7 @@ def reset_enkf(): def dark_storage_app_(monkeypatch): monkeypatch.setenv("ERT_STORAGE_NO_TOKEN", "yup") monkeypatch.setenv("ERT_STORAGE_ENS_PATH", "storage") + update_storage() yield app reset_enkf() diff --git a/tests/ert/unit_tests/dark_storage/test_dark_storage_state.py b/tests/ert/unit_tests/dark_storage/test_dark_storage_state.py index a81905f37e0..81399761394 100644 --- a/tests/ert/unit_tests/dark_storage/test_dark_storage_state.py +++ b/tests/ert/unit_tests/dark_storage/test_dark_storage_state.py @@ -1,11 +1,12 @@ import io import os +from urllib.parse import quote from uuid import UUID import hypothesis.strategies as st import pandas as pd import pytest -from hypothesis import assume, settings +from hypothesis import assume from hypothesis.stateful import rule from starlette.testclient import TestClient @@ -14,7 +15,10 @@ from tests.ert.unit_tests.storage.test_local_storage import StatefulStorageTest -@settings(max_examples=1000) +def escape(s): + return quote(quote(quote(s, safe=""))) + + class DarkStorageStateTest(StatefulStorageTest): def __init__(self): super().__init__() @@ -40,9 +44,11 @@ def get_experiments_through_client(self): @rule(model_experiment=StatefulStorageTest.experiments) def get_observations_through_client(self, model_experiment): response = self.client.get(f"/experiments/{model_experiment.uuid}/observations") - assert {r["name"] for r in response.json()} == set( - model_experiment.observations.keys() - ) + assert {r["name"] for r in response.json()} == { + key + for _, ds in model_experiment.observations.items() + for key in ds["observation_key"] + } @rule(model_experiment=StatefulStorageTest.experiments) def get_ensembles_through_client(self, model_experiment): @@ -55,14 +61,15 @@ def get_ensembles_through_client(self, model_experiment): def get_responses_through_client(self, model_ensemble): response = self.client.get(f"/ensembles/{model_ensemble.uuid}/responses") response_names = { - k for r in model_ensemble.response_values.values() for k in r["name"].values + k + for r in model_ensemble.response_values.values() + for k in r["response_key"] } assert set(response.json().keys()) == response_names @rule(model_ensemble=StatefulStorageTest.ensembles, data=st.data()) def get_response_csv_through_client(self, model_ensemble, data): assume(model_ensemble.response_values) - print("Hit it!") response_key, response_name = data.draw( st.sampled_from( [ @@ -75,16 +82,17 @@ def get_response_csv_through_client(self, model_ensemble, data): df = pd.read_parquet( io.BytesIO( self.client.get( - f"/ensembles/{model_ensemble.uuid}/records/{response_name}", + f"/ensembles/{model_ensemble.uuid}/records/{escape(response_name)}", headers={"accept": "application/x-parquet"}, ).content ) ) - assert set(df.columns) == set( - model_ensemble.response_values[response_key] + assert {dt[:10] for dt in df.columns} == { + str(dt)[:10] + for dt in model_ensemble.response_values[response_key] .sel(name=response_name)["time"] .values - ) + } def teardown(self): super().teardown() diff --git a/tests/ert/unit_tests/gui/tools/plot/conftest.py b/tests/ert/unit_tests/gui/tools/plot/conftest.py index c5eba03ab5a..dc59088248b 100644 --- a/tests/ert/unit_tests/gui/tools/plot/conftest.py +++ b/tests/ert/unit_tests/gui/tools/plot/conftest.py @@ -209,9 +209,9 @@ def mocked_requests_get(*args, **kwargs): records = { "/ensembles/ens_id_3/records/FOPR": summary_parquet_data, - "/ensembles/ens_id_3/records/BPR:1,3,8": summary_parquet_data, - "/ensembles/ens_id_3/records/SNAKE_OIL_PARAM:BPR_138_PERSISTENCE": parameter_parquet_data, - "/ensembles/ens_id_3/records/SNAKE_OIL_PARAM:OP1_DIVERGENCE_SCALE": parameter_parquet_data, + "/ensembles/ens_id_3/records/BPR%25253A1%25252C3%25252C8": summary_parquet_data, + "/ensembles/ens_id_3/records/SNAKE_OIL_PARAM%25253ABPR_138_PERSISTENCE": parameter_parquet_data, + "/ensembles/ens_id_3/records/SNAKE_OIL_PARAM%25253AOP1_DIVERGENCE_SCALE": parameter_parquet_data, "/ensembles/ens_id_3/records/SNAKE_OIL_WPR_DIFF@199": gen_parquet_data, "/ensembles/ens_id_3/records/FOPRH": history_parquet_data, } diff --git a/tests/ert/unit_tests/gui/tools/plot/test_plot_api.py b/tests/ert/unit_tests/gui/tools/plot/test_plot_api.py index 75c031b1a1a..07b05ba1d70 100644 --- a/tests/ert/unit_tests/gui/tools/plot/test_plot_api.py +++ b/tests/ert/unit_tests/gui/tools/plot/test_plot_api.py @@ -1,12 +1,39 @@ +from datetime import datetime +from textwrap import dedent +from urllib.parse import quote + import httpx import pandas as pd +import polars import pytest from pandas.testing import assert_frame_equal - -from ert.gui.tools.plot.plot_api import PlotApiKeyDefinition +from starlette.testclient import TestClient + +from ert.config import SummaryConfig +from ert.dark_storage.app import app +from ert.dark_storage.enkf import update_storage +from ert.gui.tools.plot.plot_api import PlotApi, PlotApiKeyDefinition +from ert.services import StorageService +from ert.storage import open_storage from tests.ert.unit_tests.gui.tools.plot.conftest import MockResponse +@pytest.fixture(autouse=True) +def use_testclient(monkeypatch): + client = TestClient(app) + monkeypatch.setattr(StorageService, "session", lambda: client) + + def test_escape(s: str) -> str: + """ + Workaround for issue with TestClient: + https://github.com/encode/starlette/issues/1060 + """ + print("TESTESCAPING") + return quote(quote(quote(s, safe=""))) + + PlotApi.escape = test_escape + + def test_key_def_structure(api): key_defs = api.all_data_type_keys() fopr = next(x for x in key_defs if x.key == "FOPR") @@ -146,3 +173,63 @@ def test_plot_api_request_errors(api): with pytest.raises(httpx.RequestError): api.data_for_key(ensemble.id, "should_not_be_there") + + +def test_plot_api_handles_urlescape(tmp_path, monkeypatch): + with open_storage(tmp_path / "storage", mode="w") as storage: + monkeypatch.setenv("ERT_STORAGE_NO_TOKEN", "yup") + monkeypatch.setenv("ERT_STORAGE_ENS_PATH", storage.path) + update_storage() + api = PlotApi() + key = "WBHP:46/3-7S" + date = datetime(year=2024, month=10, day=4) + experiment = storage.create_experiment( + parameters=[], + responses=[ + SummaryConfig( + name="summary", + input_files=["CASE.UNSMRY", "CASE.SMSPEC"], + keys=[key], + ) + ], + observations={ + "summary": polars.DataFrame( + { + "response_key": key, + "observation_key": "sumobs", + "time": polars.Series([date]).dt.cast_time_unit("ms"), + "observations": polars.Series([1.0], dtype=polars.Float32), + "std": polars.Series([1.0], dtype=polars.Float32), + } + ) + }, + ) + ensemble = experiment.create_ensemble(ensemble_size=1, name="ensemble") + assert api.data_for_key(str(ensemble.id), key).empty + df = polars.DataFrame( + { + "response_key": [key], + "time": [polars.Series([date]).dt.cast_time_unit("ms")], + "values": [polars.Series([1.0], dtype=polars.Float32)], + } + ) + df = df.explode("values", "time") + ensemble.save_response( + "summary", + df, + 0, + ) + assert api.data_for_key(str(ensemble.id), key).to_csv() == dedent( + """\ + Realization,2024-10-04 + 0,1.0 + """ + ) + assert api.observations_for_key([str(ensemble.id)], key).to_csv() == dedent( + """\ + ,0 + STD,1.0 + OBS,1.0 + key_index,2024-10-04 00:00:00 + """ + ) diff --git a/tests/ert/unit_tests/storage/test_local_storage.py b/tests/ert/unit_tests/storage/test_local_storage.py index ccf2abc54e9..df403641d9d 100644 --- a/tests/ert/unit_tests/storage/test_local_storage.py +++ b/tests/ert/unit_tests/storage/test_local_storage.py @@ -545,7 +545,7 @@ class Experiment: ensembles: Dict[UUID, Ensemble] = field(default_factory=dict) parameters: List[ParameterConfig] = field(default_factory=list) responses: List[ResponseConfig] = field(default_factory=list) - observations: Dict[str, xr.Dataset] = field(default_factory=dict) + observations: Dict[str, polars.DataFrame] = field(default_factory=dict) class StatefulStorageTest(RuleBasedStateMachine):