diff --git a/tests/data/snake_oil_data.py b/tests/data/snake_oil_data.py
index cde60e9b..8f8240a9 100644
--- a/tests/data/snake_oil_data.py
+++ b/tests/data/snake_oil_data.py
@@ -156,6 +156,8 @@ def content(self):
             "id": "FOPR",
         },
     },
+    "http://127.0.0.1:5000/ensembles/1/records/OP1_DIVERGENCE_SCALE/labels": [],
+    "http://127.0.0.1:5000/ensembles/1/records/BPR_138_PERSISTENCE/labels": [],
     "http://127.0.0.1:5000/ensembles/1/records/SNAKE_OIL_GPR_DIFF/observations?realization_index=0": [],
     "http://127.0.0.1:5000/ensembles/3/records/SNAKE_OIL_GPR_DIFF?realization_index=0": pd.DataFrame(
         [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
@@ -235,6 +237,16 @@ def to_parquet_helper(dataframe: pd.DataFrame) -> bytes:
     ).transpose()
 )
 
+ensembles_response[
+    "http://127.0.0.1:5000/ensembles/1/records/SNAKE_OIL_GPR_DIFF"
+] = to_parquet_helper(
+    pd.DataFrame(
+        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
+        columns=["0"],
+        index=["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"],
+    ).transpose()
+)
+
 ensembles_response[
     "http://127.0.0.1:5000/ensembles/1/records/OP1_DIVERGENCE_SCALE"
 ] = to_parquet_helper(
@@ -255,3 +267,65 @@ def to_parquet_helper(dataframe: pd.DataFrame) -> bytes:
         index=["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"],
     ).transpose()
 )
+
+ensembles_response.update(
+    {
+        "http://127.0.0.1:5000/ensembles/42": {
+            "data": {
+                "ensemble": {
+                    "children": [],
+                    "experiment": {"id": "exp1_id"},
+                    "parent": None,
+                    "id": 1,
+                    "timeCreated": "2020-04-29T09:36:26",
+                    "size": 1,
+                    "activeRealizations": [0],
+                    "userdata": '{"name": "default"}',
+                }
+            }
+        },
+        "http://127.0.0.1:5000/ensembles/42/parameters": [
+            "test_parameter_1",
+            "test_parameter_2",
+        ],
+        "http://127.0.0.1:5000/ensembles/42/responses": {
+            "test_response": {
+                "name": "name_test_response",
+                "id": "test_response_id_1",
+            },
+        },
+        "http://127.0.0.1:5000/ensembles/42/records/test_parameter_1/labels": [],
+        "http://127.0.0.1:5000/ensembles/42/records/test_parameter_2/labels": [
+            "a",
+            "b",
+        ],
+    }
+)
+
+ensembles_response[
+    "http://127.0.0.1:5000/ensembles/42/records/test_parameter_1?"
+] = to_parquet_helper(
+    pd.DataFrame(
+        [0.1, 1.1, 2.1],
+        columns=["0"],
+        index=["0", "1", "2"],
+    ).transpose()
+)
+ensembles_response[
+    "http://127.0.0.1:5000/ensembles/42/records/test_parameter_2?label=a"
+] = to_parquet_helper(
+    pd.DataFrame(
+        [0.01, 1.01, 2.01],
+        columns=["a"],
+        index=["0", "1", "2"],
+    ).transpose()
+)
+ensembles_response[
+    "http://127.0.0.1:5000/ensembles/42/records/test_parameter_2?label=b"
+] = to_parquet_helper(
+    pd.DataFrame(
+        [0.02, 1.02, 2.02],
+        columns=["b"],
+        index=["0", "1", "2"],
+    ).transpose()
+)
diff --git a/tests/models/test_ensemble_model.py b/tests/models/test_ensemble_model.py
index 803a3d4e..ba328516 100644
--- a/tests/models/test_ensemble_model.py
+++ b/tests/models/test_ensemble_model.py
@@ -7,3 +7,44 @@ def test_ensemble_model(mock_data):
     assert ens_model.children[0]._name == "default_smoother_update"
     assert ens_model._name == "default"
     assert len(ens_model.responses) == 1
+
+
+def test_ensemble_model_labeled_parameters(mock_data):
+    ens_id = 42
+    ens_model = EnsembleModel(ensemble_id=ens_id, project_id=None)
+    assert ens_model._name == "default"
+    assert len(ens_model.parameters) == 3
+    for param_name, parameter in ens_model.parameters.items():
+        name, label = (
+            param_name.split("::", maxsplit=1)
+            if "::" in param_name
+            else [param_name, None]
+        )
+        expected_labels = ens_model._data_loader.get_record_labels(ens_id, name)
+        if label is not None:
+            assert label in expected_labels
+
+
+def test_ensemble_model_parameter_data(mock_data):
+    ens_id = 42
+    ens_model = EnsembleModel(ensemble_id=ens_id, project_id=None)
+    parameters = ens_model.parameters
+    assert len(parameters) == 3
+
+    # Parameter no labels:
+    expected_labels = ens_model._data_loader.get_record_labels(
+        ens_id, "test_parameter_1"
+    )
+    assert expected_labels == []
+    data = parameters["test_parameter_1"].data_df().values
+    assert data.flatten().tolist() == [0.1, 1.1, 2.1]
+
+    # Parameter with labels:
+    expected_labels = ens_model._data_loader.get_record_labels(
+        ens_id, "test_parameter_2"
+    )
+    assert expected_labels == ["a", "b"]
+    data = parameters["test_parameter_2::a"].data_df()["a"].values.tolist()
+    assert data == [0.01, 1.01, 2.01]
+    data = parameters["test_parameter_2::b"].data_df()["b"].values.tolist()
+    assert data == [0.02, 1.02, 2.02]
diff --git a/webviz_ert/data_loader/__init__.py b/webviz_ert/data_loader/__init__.py
index 5b03b75e..a3d9c30e 100644
--- a/webviz_ert/data_loader/__init__.py
+++ b/webviz_ert/data_loader/__init__.py
@@ -1,5 +1,5 @@
 import json
-from typing import Any, Mapping, Optional, List, MutableMapping, Tuple
+from typing import Any, Mapping, Optional, List, MutableMapping, Tuple, Dict
 from collections import defaultdict
 from pprint import pformat
 import requests
@@ -113,7 +113,7 @@ class DataLoader:
     _instances: MutableMapping[ServerIdentifier, "DataLoader"] = {}
 
     baseurl: str
-    token: str
+    token: Optional[str]
     _graphql_cache: MutableMapping[str, MutableMapping[dict, Any]]
 
     def __new__(cls, baseurl: str, token: Optional[str] = None) -> "DataLoader":
@@ -195,45 +195,62 @@ def get_ensemble_userdata(self, ensemble_id: str) -> dict:
 
     def get_ensemble_parameters(self, ensemble_id: str) -> list:
         return self._get(url=f"ensembles/{ensemble_id}/parameters").json()
 
+    def get_record_labels(self, ensemble_id: str, name: str) -> list:
+        return self._get(url=f"ensembles/{ensemble_id}/records/{name}/labels").json()
+
     def get_experiment_priors(self, experiment_id: str) -> dict:
         return json.loads(
             self._query(GET_PRIORS, id=experiment_id)["experiment"]["priors"]
         )
 
     def get_ensemble_parameter_data(
-        self, ensemble_id: str, parameter_name: str
+        self,
+        ensemble_id: str,
+        parameter_name: str,
     ) -> pd.DataFrame:
-        resp = self._get(
-            url=f"ensembles/{ensemble_id}/records/{parameter_name}",
-            headers={"accept": "application/x-parquet"},
-        )
-        stream = io.BytesIO(resp.content)
-        df = pd.read_parquet(stream)
-        return df
+        try:
+            if "::" in parameter_name:
+                name, label = parameter_name.split("::", 1)
+                params = {"label": label}
+            else:
+                name = parameter_name
+                params = {}
+
+            resp = self._get(
+                url=f"ensembles/{ensemble_id}/records/{name}",
+                headers={"accept": "application/x-parquet"},
+                params=params,
+            )
+            stream = io.BytesIO(resp.content)
+            df = pd.read_parquet(stream).transpose()
+            return df
+        except DataLoaderException as e:
+            logger.error(e)
+            return pd.DataFrame()
 
     def get_ensemble_record_data(
-        self, ensemble_id: str, record_name: str, active_realizations: List[int]
+        self,
+        ensemble_id: str,
+        record_name: str,
     ) -> pd.DataFrame:
-        dfs = []
-        for rel_idx in active_realizations:
-            try:
-                resp = self._get(
-                    url=f"ensembles/{ensemble_id}/records/{record_name}",
-                    headers={"accept": "application/x-parquet"},
-                    params={"realization_index": rel_idx},
-                )
-                stream = io.BytesIO(resp.content)
-                df = pd.read_parquet(stream).transpose()
-                df.columns = [rel_idx]
-                dfs.append(df)
-
-            except DataLoaderException as e:
-                logger.error(e)
-
-        if dfs == []:
+        try:
+            resp = self._get(
+                url=f"ensembles/{ensemble_id}/records/{record_name}",
+                headers={"accept": "application/x-parquet"},
+            )
+            stream = io.BytesIO(resp.content)
+            df = pd.read_parquet(stream).transpose()
+
+        except DataLoaderException as e:
+            logger.error(e)
             return pd.DataFrame()
-        return pd.concat(dfs, axis=1)
+        try:
+            df.index = df.index.astype(int)
+        except TypeError:
+            pass
+        df = df.sort_index()
+        return df
 
     def get_ensemble_record_observations(
         self, ensemble_id: str, record_name: str
diff --git a/webviz_ert/models/ensemble_model.py b/webviz_ert/models/ensemble_model.py
index 3e850bcd..82d40876 100644
--- a/webviz_ert/models/ensemble_model.py
+++ b/webviz_ert/models/ensemble_model.py
@@ -1,15 +1,16 @@
 import json
 import pandas as pd
 from typing import Mapping, List, Dict, Union, Any, Optional
 
-from webviz_ert.data_loader import (
-    get_data_loader,
-)
+from webviz_ert.data_loader import get_data_loader, DataLoaderException
 from webviz_ert.models import Response, PriorModel, ParametersModel
 
 
 def _create_parameter_models(
-    parameters_names: list, priors: dict, ensemble_id: str, project_id: str
+    parameters_names: list,
+    priors: dict,
+    ensemble_id: str,
+    project_id: str,
 ) -> Optional[Mapping[str, ParametersModel]]:
     parameters = {}
     for param in parameters_names:
@@ -98,7 +99,14 @@ def parameters(
         self,
     ) -> Optional[Mapping[str, ParametersModel]]:
         if not self._parameters:
-            parameter_names = self._data_loader.get_ensemble_parameters(self._id)
+            parameter_names = []
+            for param_name in self._data_loader.get_ensemble_parameters(self._id):
+                labels = self._data_loader.get_record_labels(self._id, param_name)
+                if len(labels) > 0:
+                    for label in labels:
+                        parameter_names.append(f"{param_name}::{label}")
+                else:
+                    parameter_names.append(param_name)
             parameter_priors = (
                 self._data_loader.get_experiment_priors(self._experiment_id)
                 if not self._parent
diff --git a/webviz_ert/models/parameter_model.py b/webviz_ert/models/parameter_model.py
index a7d7c8b1..24545410 100644
--- a/webviz_ert/models/parameter_model.py
+++ b/webviz_ert/models/parameter_model.py
@@ -30,10 +30,10 @@ def __init__(self, **kwargs: Any):
     def data_df(self) -> pd.DataFrame:
         if self._data_df.empty:
             _data_df = self._data_loader.get_ensemble_parameter_data(
-                ensemble_id=self._ensemble_id, parameter_name=self.key
+                ensemble_id=self._ensemble_id,
+                parameter_name=self.key,
             )
             if _data_df is not None:
-                _data_df = _data_df.transpose()
                 _data_df.index.name = self.key
                 self._data_df = _data_df
         return self._data_df
diff --git a/webviz_ert/models/response.py b/webviz_ert/models/response.py
index 85ff2886..82505325 100644
--- a/webviz_ert/models/response.py
+++ b/webviz_ert/models/response.py
@@ -40,7 +40,7 @@ def axis(self) -> Optional[List[Union[int, str, datetime.datetime]]]:
 
     def data(self) -> pd.DataFrame:
         if self._data is None:
             self._data = self._data_loader.get_ensemble_record_data(
-                self._ensemble_id, self.name, self._active_realizations
+                self._ensemble_id, self.name
             )
         return self._data