Skip to content

Commit

Permalink
Remove unused index_lists
Browse files Browse the repository at this point in the history
  • Loading branch information
Yngve S. Kristiansen authored and yngve-sk committed Sep 16, 2024
1 parent dbf2ae6 commit 47c7f5f
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 56 deletions.
42 changes: 2 additions & 40 deletions src/ert/data/_measured_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,12 @@

from __future__ import annotations

from datetime import datetime
from typing import TYPE_CHECKING, List, Optional, Union
from typing import TYPE_CHECKING, List, Optional

import numpy as np
import pandas as pd

if TYPE_CHECKING:
import numpy.typing as npt

from ert.storage import Ensemble


Expand All @@ -25,21 +22,13 @@ class ResponseError(Exception):


class MeasuredData:
def __init__(
self,
ensemble: Ensemble,
keys: Optional[List[str]] = None,
index_lists: Optional[List[List[Union[int, datetime]]]] = None,
):
def __init__(self, ensemble: Ensemble, keys: Optional[List[str]] = None):
if keys is None:
keys = sorted(ensemble.experiment.observations.keys())
if not keys:
raise ObservationError("No observation keys provided")
if index_lists is not None and len(index_lists) != len(keys):
raise ValueError("index list must be same length as observations keys")

self._set_data(self._get_data(ensemble, keys))
self._set_data(self.filter_on_column_index(keys, index_lists))

@property
def data(self) -> pd.DataFrame:
Expand Down Expand Up @@ -167,33 +156,6 @@ def _get_data(

return pd.concat(measured_data, axis=1)

def filter_on_column_index(
self,
obs_keys: List[str],
index_lists: Optional[List[List[Union[int, datetime]]]] = None,
) -> pd.DataFrame:
if index_lists is None or all(index_list is None for index_list in index_lists):
return self.data
names = self.data.columns.get_level_values(0)
data_index = self.data.columns.get_level_values("key_index")
cond = self._create_condition(names, data_index, obs_keys, index_lists)
return self.data.iloc[:, cond]

@staticmethod
def _create_condition(
names: pd.Index,
data_index: pd.Index,
obs_keys: List[str],
index_lists: List[List[Union[int, datetime]]],
) -> "npt.NDArray[np.bool_]":
conditions = []
for obs_key, index_list in zip(obs_keys, index_lists):
if index_list is not None:
index_cond = [data_index == index for index in index_list]
index_cond = np.logical_or.reduce(index_cond)
conditions.append(np.logical_and(index_cond, (names == obs_key)))
return np.logical_or.reduce(conditions)


class ObservationError(Exception):
pass
16 changes: 0 additions & 16 deletions tests/unit_tests/data/test_integration_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,22 +98,6 @@ def test_gen_obs_and_summary(create_measured_data):
]


def test_gen_obs_and_summary_index_range(create_measured_data):
df = create_measured_data(["WPR_DIFF_1", "FOPR"], [[800], [datetime(2010, 4, 20)]])
df.remove_inactive_observations()

assert df.data.columns.get_level_values(0).to_list() == [
"WPR_DIFF_1",
"FOPR",
]
assert df.data.columns.get_level_values("data_index").to_list() == [
800,
10,
]
assert df.data.loc["OBS"].values == pytest.approx([0.1, 0.23281], abs=0.00001)
assert df.data.loc["STD"].values == pytest.approx([0.2, 0.1])


@pytest.mark.parametrize(
"obs_key, expected_msg",
[
Expand Down

0 comments on commit 47c7f5f

Please sign in to comment.