Skip to content

Commit

Permalink
Add more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
larsevj committed Oct 8, 2024
1 parent 7e0671a commit 2ce1d41
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 2 deletions.
15 changes: 14 additions & 1 deletion src/ert/config/design_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import TYPE_CHECKING, List, Optional

import pandas as pd
from pandas.api.types import is_integer_dtype

from ert.config.gen_kw_config import GenKwConfig, TransformFunctionDefinition

Expand All @@ -27,6 +28,8 @@ class DesignMatrix:
xls_filename: Path
design_sheet: str
default_sheet: str
num_realizations: Optional[int] = None
active_realizations: Optional[list[int]] = None
design_matrix_df: Optional[pd.DataFrame] = None
parameter_configuration: Optional[dict[str, ParameterConfig]] = None

Expand Down Expand Up @@ -80,7 +83,13 @@ def read_design_matrix(
self.xls_filename, self.design_sheet
)
if "REAL" in design_matrix_df.columns:
design_matrix_df = design_matrix_df.set_index("REAL", drop=True)
if not is_integer_dtype(design_matrix_df.dtypes["REAL"]) or any(
design_matrix_df["REAL"] < 0
):
raise ValueError("REAL column must only contain positive integers")
design_matrix_df = design_matrix_df.set_index(
"REAL", drop=True, verify_integrity=True
)
try:
DesignMatrix._validate_design_matrix_header(design_matrix_df)
except ValueError as err:
Expand Down Expand Up @@ -119,6 +128,10 @@ def read_design_matrix(
design_matrix_df.columns = pd.MultiIndex.from_product(
[[DESIGN_MATRIX_GROUP], design_matrix_df.columns]
)
reals = design_matrix_df.index.tolist()
self.num_realizations = len(reals)
self.active_realizations = [x in reals for x in range(max(reals))]

self.design_matrix_df = design_matrix_df
self.parameter_configuration = parameter_configuration

Expand Down
70 changes: 69 additions & 1 deletion tests/ert/unit_tests/sensitivity_analysis/test_design_matrix.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,80 @@
import numpy as np
import pandas as pd
import pytest

from ert.config import DesignMatrix


def test_reading_design_matrix(tmp_path):
design_path = tmp_path / "design_matrix.xlsx"
design_matrix_df = pd.DataFrame(
{"REAL": [0, 1, 2], "a": [1, 2, 3], "b": [0, 2, 0], "c": [3, 1, 3]}
{
"REAL": [0, 1, 2],
"a": [1, 2, 3],
"b": [0, 2, 0],
"c": [3, 1, 3],
}
)
default_sheet_df = pd.DataFrame([["one", 1], ["b", 4], ["d", 6]])
with pd.ExcelWriter(design_path) as xl_write:
design_matrix_df.to_excel(xl_write, index=False, sheet_name="DesignSheet01")
default_sheet_df.to_excel(
xl_write, index=False, sheet_name="DefaultValues", header=False
)
design_matrix = DesignMatrix(design_path, "DesignSheet01", "DefaultValues")
design_matrix.read_design_matrix()


@pytest.mark.parametrize(
"real_column, error_msg",
[
pytest.param([0, 1, 1], "Index has duplicate keys", id="duplicate entries"),
pytest.param(
[0, 1.1, 2],
"REAL column must only contain positive integers",
id="invalid float values",
),
pytest.param(
[0, "a", 10],
"REAL column must only contain positive integers",
id="invalid types",
),
],
)
def test_reading_design_matrix_validate_reals(tmp_path, real_column, error_msg):
design_path = tmp_path / "design_matrix.xlsx"
design_matrix_df = pd.DataFrame(
{
"REAL": real_column,
"a": [1, 2, 3],
"b": [0, 2, 0],
"c": [3, 1, 3],
}
)
default_sheet_df = pd.DataFrame()
with pd.ExcelWriter(design_path) as xl_write:
design_matrix_df.to_excel(xl_write, index=False, sheet_name="DesignSheet01")
default_sheet_df.to_excel(
xl_write, index=False, sheet_name="DefaultValues", header=False
)
design_matrix = DesignMatrix(design_path, "DesignSheet01", "DefaultValues")
with pytest.raises(ValueError, match=error_msg):
design_matrix.read_design_matrix()


def test_reading_design_matrix_duplicate_columns(tmp_path):
design_path = tmp_path / "design_matrix.xlsx"
design_matrix_df = pd.DataFrame(
{
"REAL": [0, 1, -4],
"a": [1, 2, 3],
"b": [0, 2, 0],
"c": [3, 1, 3],
"0": ["a", 2, "c"],
}
)
design_matrix_df = pd.DataFrame(
np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), columns=["a", "b", "a"]
)
default_sheet_df = pd.DataFrame([["one", 1], ["b", 4], ["d", 6]])
with pd.ExcelWriter(design_path) as xl_write:
Expand Down

0 comments on commit 2ce1d41

Please sign in to comment.