Skip to content

Commit

Permalink
Fix performance tests
Browse files Browse the repository at this point in the history
  • Loading branch information
frode-aarstad committed Dec 22, 2023
1 parent 9993224 commit 99cf2d0
Show file tree
Hide file tree
Showing 8 changed files with 77 additions and 82 deletions.
26 changes: 10 additions & 16 deletions src/ert/dark_storage/common.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,32 @@
import logging
from typing import Any, Dict, List, Optional, Sequence, Union
from uuid import UUID

import pandas as pd

from ert.config import EnkfObservationImplementationType
from ert.libres_facade import LibresFacade
from ert.storage import EnsembleReader
from ert.storage import EnsembleReader, StorageReader

_logger = logging.getLogger(__name__)

def ensemble_parameter_names(storage: StorageReader, ensemble_id: UUID) -> List[str]:
return storage.get_ensemble(ensemble_id).get_gen_kw_keyset()

def ensemble_parameter_names(res: LibresFacade) -> List[str]:
return res.gen_kw_keys()


def ensemble_parameters(res: LibresFacade) -> List[Dict[str, Any]]:
def ensemble_parameters(
storage: StorageReader, ensemble_id: UUID
) -> List[Dict[str, Any]]:
return [
{"name": key, "userdata": {"data_origin": "GEN_KW"}, "labels": []}
for key in ensemble_parameter_names(res)
for key in ensemble_parameter_names(storage, ensemble_id)
]


def get_response_names(res: LibresFacade, ensemble: EnsembleReader) -> List[str]:
def get_response_names(ensemble: EnsembleReader) -> List[str]:
result = ensemble.get_summary_keyset()
result.extend(res.get_gen_data_keys().copy())
result.extend(ensemble.get_gen_data_keyset().copy())
return result


####


def data_for_key(
ensemble: EnsembleReader,
key: str,
Expand Down Expand Up @@ -70,9 +67,6 @@ def data_for_key(
return data


#####################


def observations_for_obs_keys(
res: LibresFacade, obs_keys: List[str]
) -> List[Dict[str, Any]]:
Expand Down
22 changes: 8 additions & 14 deletions src/ert/dark_storage/endpoints/ensembles.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,17 @@

from ert.dark_storage import json_schema as js
from ert.dark_storage.common import ensemble_parameter_names, get_response_names
from ert.dark_storage.enkf import LibresFacade, get_res, get_storage
from ert.dark_storage.enkf import get_storage
from ert.storage import StorageAccessor

router = APIRouter(tags=["ensemble"])
DEFAULT_LIBRESFACADE = Depends(get_res)
DEFAULT_STORAGE = Depends(get_storage)
DEFAULT_BODY = Body(...)


@router.post("/experiments/{experiment_id}/ensembles", response_model=js.EnsembleOut)
def post_ensemble(
*,
res: LibresFacade = DEFAULT_LIBRESFACADE,
ens_in: js.EnsembleIn,
experiment_id: UUID,
) -> js.EnsembleOut:
Expand All @@ -27,28 +25,26 @@ def post_ensemble(
@router.get("/ensembles/{ensemble_id}", response_model=js.EnsembleOut)
def get_ensemble(
*,
res: LibresFacade = DEFAULT_LIBRESFACADE,
db: StorageAccessor = DEFAULT_STORAGE,
storage: StorageAccessor = DEFAULT_STORAGE,
ensemble_id: UUID,
) -> js.EnsembleOut:
ens = db.get_ensemble(ensemble_id)
ensemble = storage.get_ensemble(ensemble_id)
return js.EnsembleOut(
id=ensemble_id,
children=[],
parent=None,
experiment_id=ens.experiment_id,
userdata={"name": ens.name},
size=ens.ensemble_size,
parameter_names=ensemble_parameter_names(res),
response_names=get_response_names(res, ens),
experiment_id=ensemble.experiment_id,
userdata={"name": ensemble.name},
size=ensemble.ensemble_size,
parameter_names=ensemble_parameter_names(storage, ensemble_id),
response_names=get_response_names(ensemble),
child_ensemble_ids=[],
)


@router.put("/ensembles/{ensemble_id}/userdata")
async def replace_ensemble_userdata(
*,
res: LibresFacade = DEFAULT_LIBRESFACADE,
ensemble_id: UUID,
body: Any = DEFAULT_BODY,
) -> None:
Expand All @@ -58,7 +54,6 @@ async def replace_ensemble_userdata(
@router.patch("/ensembles/{ensemble_id}/userdata")
async def patch_ensemble_userdata(
*,
res: LibresFacade = DEFAULT_LIBRESFACADE,
ensemble_id: UUID,
body: Any = DEFAULT_BODY,
) -> None:
Expand All @@ -68,7 +63,6 @@ async def patch_ensemble_userdata(
@router.get("/ensembles/{ensemble_id}/userdata", response_model=Mapping[str, Any])
async def get_ensemble_userdata(
*,
res: LibresFacade = DEFAULT_LIBRESFACADE,
ensemble_id: UUID,
) -> Mapping[str, Any]:
raise NotImplementedError
6 changes: 1 addition & 5 deletions src/ert/dark_storage/endpoints/observations.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
)
def post_observation(
*,
res: LibresFacade = DEFAULT_LIBRESFACADE,
obs_in: js.ObservationIn,
experiment_id: UUID,
) -> js.ObservationOut:
Expand Down Expand Up @@ -48,15 +47,14 @@ def get_observations(
"/ensembles/{ensemble_id}/observations", response_model=List[js.ObservationOut]
)
def get_observations_with_transformation(
*, res: LibresFacade = DEFAULT_LIBRESFACADE, ensemble_id: UUID
*, ensemble_id: UUID
) -> List[js.ObservationOut]:
raise NotImplementedError


@router.put("/observations/{obs_id}/userdata")
async def replace_observation_userdata(
*,
res: LibresFacade = DEFAULT_LIBRESFACADE,
obs_id: UUID,
body: Any = DEFAULT_BODY,
) -> None:
Expand All @@ -66,7 +64,6 @@ async def replace_observation_userdata(
@router.patch("/observations/{obs_id}/userdata")
async def patch_observation_userdata(
*,
res: LibresFacade = DEFAULT_LIBRESFACADE,
obs_id: UUID,
body: Any = DEFAULT_BODY,
) -> None:
Expand All @@ -76,7 +73,6 @@ async def patch_observation_userdata(
@router.get("/observations/{obs_id}/userdata", response_model=Mapping[str, Any])
async def get_observation_userdata(
*,
res: LibresFacade = DEFAULT_LIBRESFACADE,
obs_id: UUID,
) -> Mapping[str, Any]:
raise NotImplementedError
23 changes: 4 additions & 19 deletions src/ert/dark_storage/endpoints/records.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
@router.post("/ensembles/{ensemble_id}/records/{name}/file")
async def post_ensemble_record_file(
*,
res: LibresFacade = DEFAULT_LIBRESFACADE,
name: str,
ensemble_id: UUID,
realization_index: Optional[int] = None,
Expand All @@ -42,7 +41,6 @@ async def post_ensemble_record_file(
@router.put("/ensembles/{ensemble_id}/records/{name}/blob")
async def add_block(
*,
res: LibresFacade = DEFAULT_LIBRESFACADE,
name: str,
ensemble_id: UUID,
block_index: int,
Expand All @@ -55,7 +53,6 @@ async def add_block(
@router.post("/ensembles/{ensemble_id}/records/{name}/blob")
async def create_blob(
*,
res: LibresFacade = DEFAULT_LIBRESFACADE,
name: str,
ensemble_id: UUID,
realization_index: Optional[int] = None,
Expand All @@ -66,7 +63,6 @@ async def create_blob(
@router.patch("/ensembles/{ensemble_id}/records/{name}/blob")
async def finalize_blob(
*,
res: LibresFacade = DEFAULT_LIBRESFACADE,
name: str,
ensemble_id: UUID,
realization_index: Optional[int] = None,
Expand All @@ -79,7 +75,6 @@ async def finalize_blob(
)
async def post_ensemble_record_matrix(
*,
res: LibresFacade = DEFAULT_LIBRESFACADE,
ensemble_id: UUID,
name: str,
prior: Optional[str] = None,
Expand All @@ -93,7 +88,6 @@ async def post_ensemble_record_matrix(
@router.put("/ensembles/{ensemble_id}/records/{name}/userdata")
async def replace_record_userdata(
*,
res: LibresFacade = DEFAULT_LIBRESFACADE,
ensemble_id: UUID,
name: str,
realization_index: Optional[int] = None,
Expand All @@ -105,7 +99,6 @@ async def replace_record_userdata(
@router.patch("/ensembles/{ensemble_id}/records/{name}/userdata")
async def patch_record_userdata(
*,
res: LibresFacade = DEFAULT_LIBRESFACADE,
ensemble_id: UUID,
name: str,
realization_index: Optional[int] = None,
Expand All @@ -119,7 +112,6 @@ async def patch_record_userdata(
)
async def get_record_userdata(
*,
res: LibresFacade = DEFAULT_LIBRESFACADE,
ensemble_id: UUID,
name: str,
realization_index: Optional[int] = None,
Expand All @@ -130,7 +122,6 @@ async def get_record_userdata(
@router.post("/ensembles/{ensemble_id}/records/{name}/observations")
async def post_record_observations(
*,
res: LibresFacade = DEFAULT_LIBRESFACADE,
ensemble_id: UUID,
name: str,
realization_index: Optional[int] = None,
Expand Down Expand Up @@ -214,7 +205,6 @@ async def get_ensemble_record(
@router.get("/ensembles/{ensemble_id}/records/{name}/labels", response_model=List[str])
async def get_record_labels(
*,
res: LibresFacade = DEFAULT_LIBRESFACADE,
ensemble_id: UUID,
name: str,
) -> List[str]:
Expand All @@ -223,31 +213,26 @@ async def get_record_labels(

@router.get("/ensembles/{ensemble_id}/parameters", response_model=List[Dict[str, Any]])
async def get_ensemble_parameters(
*, res: LibresFacade = DEFAULT_LIBRESFACADE, ensemble_id: UUID
*, storage: StorageReader = DEFAULT_STORAGE, ensemble_id: UUID
) -> List[Dict[str, Any]]:
return ensemble_parameters(res)
return ensemble_parameters(storage, ensemble_id)


@router.get(
"/ensembles/{ensemble_id}/records", response_model=Mapping[str, js.RecordOut]
)
async def get_ensemble_records(
*, res: LibresFacade = DEFAULT_LIBRESFACADE, ensemble_id: UUID
) -> Mapping[str, js.RecordOut]:
async def get_ensemble_records(*, ensemble_id: UUID) -> Mapping[str, js.RecordOut]:
raise NotImplementedError


@router.get("/records/{record_id}", response_model=js.RecordOut)
async def get_record(
*, res: LibresFacade = DEFAULT_LIBRESFACADE, record_id: UUID
) -> js.RecordOut:
async def get_record(*, record_id: UUID) -> js.RecordOut:
raise NotImplementedError


@router.get("/records/{record_id}/data")
async def get_record_data(
*,
res: LibresFacade = DEFAULT_LIBRESFACADE,
record_id: UUID,
accept: Optional[str] = DEFAULT_HEADER,
) -> Any:
Expand Down
1 change: 0 additions & 1 deletion src/ert/libres_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,6 @@ def get_data_key_for_obs_key(self, observation_key: str) -> str:
else:
return obs.data_key

# duplicate in local_ensemble
def get_gen_data_keys(self) -> List[str]:
ensemble_config = self.config.ensemble_config
gen_data_keys = ensemble_config.get_keylist_gen_data()
Expand Down
11 changes: 5 additions & 6 deletions src/ert/shared/storage/extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,6 @@

logger = logging.getLogger()


def create_observations(ert: LibresFacade) -> List[Dict[str, Dict[str, Any]]]:
keys = [i.observation_key for i in ert.get_observations()]
return observations_for_obs_keys(ert, keys)


_PRIOR_NAME_MAP = {
"NORMAL": "normal",
"LOGNORMAL": "lognormal",
Expand All @@ -31,6 +25,11 @@ def create_observations(ert: LibresFacade) -> List[Dict[str, Dict[str, Any]]]:
}


def create_observations(ert: LibresFacade) -> List[Dict[str, Dict[str, Any]]]:
keys = [i.observation_key for i in ert.get_observations()]
return observations_for_obs_keys(ert, keys)


def create_priors(ert: LibresFacade) -> Mapping[str, Dict[str, Union[str, float]]]:
priors = {}
for group, gen_kw_priors in ert.gen_kw_priors().items():
Expand Down
5 changes: 0 additions & 5 deletions src/ert/storage/local_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,8 +230,6 @@ def load_responses(
def get_active_realizations(self) -> List[int]:
return self.realization_list(RealizationStorageState.HAS_DATA)

### summary data

def load_all_summary_data(
self,
keys: Optional[List[str]] = None,
Expand Down Expand Up @@ -279,7 +277,6 @@ def gather_summary_data(
)
return data.unstack(level="Realization")

#### gen data
def _get_gen_data_config(self, key: str) -> GenDataConfig:
config = self.experiment.response_configuration[key]
assert isinstance(config, GenDataConfig)
Expand Down Expand Up @@ -326,8 +323,6 @@ def load_gen_data(
columns=realizations,
)

###### gen_kw

def get_gen_kw_keyset(self) -> List[str]:
gen_kw_keys = [
k
Expand Down
Loading

0 comments on commit 99cf2d0

Please sign in to comment.