-
Notifications
You must be signed in to change notification settings - Fork 110
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
193 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,164 @@ | ||
from __future__ import annotations | ||
|
||
from collections import defaultdict | ||
from pathlib import Path | ||
from typing import TYPE_CHECKING | ||
|
||
import pandas as pd | ||
|
||
from ert.config.gen_kw_config import GenKwConfig | ||
|
||
if TYPE_CHECKING: | ||
from ert.config import ( | ||
ErtConfig, | ||
) | ||
|
||
DESIGN_MATRIX_GROUP = "DESIGN_MATRIX" | ||
|
||
|
||
def read_design_matrix(
    ert_config: ErtConfig,
    xlsfilename: Path | str,
    designsheetname: str = "DesignSheet01",
    defaultssheetname: str = "DefaultValues",
) -> pd.DataFrame:
    """
    Read a design matrix and its default values from an excel workbook.

    The design sheet is read with its first row as header. If it has a
    "REAL" column, that column becomes the index of the returned frame.
    Parameters that only occur in the defaults sheet are added as constant
    columns. Finally every column is labelled with a two-level header
    ``(group, parameter)``, where group is the name of the GenKw group that
    contains the parameter, or DESIGN_MATRIX_GROUP if no group does.

    :param ert_config: Configuration whose GenKw parameter groups are used
        to look up the group of each design matrix parameter
    :param xlsfilename: Path of the excel workbook
    :param designsheetname: Name of the sheet holding the design matrix
    :param defaultssheetname: Name of the sheet holding default values
    :return: Design matrix with (group, parameter) column labels
    :raises: ValueError if both sheet names are equal, if the header of the
        design sheet is invalid, or if a parameter is found in more than
        one GenKw group
    """
    # Cheap argument validation first, before any file is opened.
    if designsheetname == defaultssheetname:
        raise ValueError("Design-sheet and defaults-sheet can not be the same")

    design_matrix_sheet = _read_excel(xlsfilename, designsheetname)
    if "REAL" in design_matrix_sheet.columns:
        # set_index returns a new frame; the result must be kept, otherwise
        # the realization numbers are lost when the column is removed.
        design_matrix_sheet = design_matrix_sheet.set_index("REAL", drop=True)
    try:
        _validate_design_matrix_header(design_matrix_sheet)
    except ValueError as err:
        raise ValueError(f"Design matrix not valid, error: {err!s}") from err

    # TODO: Check for invalid realizations, drop them maybe?

    # TODO: This should probably handle/(fill in) missing values in
    # design_matrix_sheet as well
    defaults = _read_defaultssheet(xlsfilename, defaultssheetname)
    for key, value in defaults.items():
        if key not in design_matrix_sheet.columns:
            design_matrix_sheet[key] = value

    # TODO: Convert columns to numeric where possible. Using
    # pd.to_numeric with errors="ignore" is deprecated in pandas, so another
    # solution is needed.

    all_genkw_configs = [
        param_group
        for param_group in ert_config.ensemble_config.parameter_configuration
        if isinstance(param_group, GenKwConfig)
    ]

    parameter_map = []
    errors = {}
    for param in design_matrix_sheet.columns:
        # NOTE(review): relies on GenKwConfig supporting the `in` operator
        # for parameter names - confirm against GenKwConfig.
        owners = [group.name for group in all_genkw_configs if param in group]
        if len(owners) > 1:
            errors[param] = owners
            continue
        group_name = owners[0] if owners else DESIGN_MATRIX_GROUP
        parameter_map.append((group_name, param))

    if errors:
        raise ValueError(
            " ".join(
                f"The following parameter '{key}' was found in multiple"
                f" GenKw parameters groups: {value}."
                for key, value in errors.items()
            )
        )

    design_matrix_sheet.columns = pd.MultiIndex.from_tuples(parameter_map)
    return design_matrix_sheet
|
||
|
||
def _read_excel(
    file_name: Path | str,
    sheet_name: str,
    usecols: int | list[int] | None = None,
    header: int | None = 0,
) -> pd.DataFrame:
    """
    Load a single sheet of an excel workbook into a dataframe.

    Columns in which every cell is empty are removed from the result.

    :param file_name: Path of the workbook
    :param sheet_name: Name of the sheet to load
    :param usecols: Column selection, passed on to pandas.read_excel
    :param header: Header row, passed on to pandas.read_excel
    :return: Dataframe
    :raises: OSError if file not found
    :raises: ValueError if file not loaded correctly
    """
    sheet: pd.DataFrame = pd.read_excel(
        file_name,
        sheet_name=sheet_name,
        header=header,
        usecols=usecols,
    )
    without_empty_columns = sheet.dropna(axis="columns", how="all")
    return without_empty_columns
|
||
|
||
def _validate_design_matrix_header(design_matrix: pd.DataFrame) -> None:
    """
    Validate header in user inputted design matrix

    An empty design matrix is accepted without further checks.

    :param design_matrix: Design matrix as read from the spreadsheet
    :raises: ValueError if a column header is not a string (int/float
        headers in xlsx are read as int/float), or if design matrix
        contains empty headers
    """
    if design_matrix.empty:
        return

    # Check the header types explicitly instead of relying on which
    # exception pandas raises for non-string labels; an all-numeric header
    # makes the .str accessor raise AttributeError rather than ValueError.
    non_string_headers = [
        name for name in design_matrix.columns if not isinstance(name, str)
    ]
    if non_string_headers:
        raise ValueError(
            "Invalid value in design matrix header, error: "
            f"column headers must be strings, got {non_string_headers}"
        )

    # Only labels of the form "Unnamed: <N>" are reported as missing headers
    # (assumed to be the placeholder pandas generates for an empty header
    # cell). A user column that merely starts with "Unnamed" is left alone
    # instead of crashing the position parsing.
    column_indexes = []
    for name in design_matrix.columns:
        prefix, separator, position = name.partition(":")
        if prefix == "Unnamed" and separator and position.strip().isdecimal():
            column_indexes.append(int(position))
    if column_indexes:
        raise ValueError(f"Column headers not present in column {column_indexes}")
|
||
|
||
def _read_defaultssheet(
    xlsfilename: Path | str, defaultssheetname: str
) -> dict[str, str]:
    """
    Read default parameter values from the first two columns of a
    spreadsheet: parameter names in the first column, their default values
    in the second. The sheet is read without a header row.

    Values are returned as read from the sheet; no type conversion is done
    here. If a parameter name occurs more than once, the last row wins.

    :param xlsfilename: Path of the excel workbook
    :param defaultssheetname: Name of the defaults sheet; if empty, no
        file is read and there are no defaults
    :return: A dict of default values, empty if the sheet is empty
    :raises: ValueError if defaults sheet is non-empty but non-parsable
    """
    if not defaultssheetname:
        return {}

    default_df = _read_excel(
        xlsfilename, defaultssheetname, usecols=[0, 1], header=None
    )
    if default_df.empty:
        return {}
    if len(default_df.columns) < 2:
        raise ValueError("Defaults sheet must have at least two columns")

    for paramname in default_df.loc[:, 0]:
        # A blank or numeric cell is not read as str; report it as a
        # parse error instead of failing on the .strip() call below.
        if not isinstance(paramname, str):
            raise ValueError(
                f'Parameter name "{paramname}" in default values '
                "is not a string."
            )
        # Look for initial or trailing whitespace in parameter names. This
        # is disallowed as it can create user confusion and has no use-case.
        if paramname != paramname.strip():
            raise ValueError(
                f'Parameter name "{paramname}" in default values contains '
                "initial or trailing whitespace."
            )

    return dict(zip(default_df.loc[:, 0], default_df.loc[:, 1]))
25 changes: 25 additions & 0 deletions
25
tests/ert/unit_tests/sensitivity_analysis/test_design_matrix.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
import pandas as pd | ||
import pytest | ||
|
||
from ert.config import ErtConfig | ||
from ert.sensitivity_analysis.design_matrix import ( | ||
read_design_matrix, | ||
) | ||
|
||
|
||
@pytest.mark.usefixtures("copy_poly_case")
def test_reading_design_matrix():
    """
    Write a small design matrix workbook, register it in poly.ert and check
    that read_design_matrix returns its parameters and values.
    """
    design_matrix_df = pd.DataFrame(
        {"REAL": [0, 1, 2], "a": [1, 2, 3], "b": [0, 2, 0], "c": [3, 1, 3]}
    )
    default_sheet_df = pd.DataFrame()
    with pd.ExcelWriter("design_matrix.xlsx") as xl_write:
        design_matrix_df.to_excel(xl_write, index=False, sheet_name="DesignSheet01")
        default_sheet_df.to_excel(xl_write, index=False, sheet_name="DefaultValues")

    with open("poly.ert", "a", encoding="utf-8") as fhandle:
        # Surrounding newlines keep the keyword on its own line even if the
        # existing file does not end with a newline.
        fhandle.write(
            "\nDESIGN_MATRIX design_matrix.xlsx DESIGN_SHEET:DesignSheet01 DEFAULT_SHEET:DefaultValues\n"
        )
    ert_config = ErtConfig.from_file("poly.ert")
    design_frame = read_design_matrix(ert_config, "design_matrix.xlsx")

    # REAL is not a parameter; only a, b and c remain as columns, labelled
    # (group, parameter). The group names depend on the poly case config and
    # are deliberately not asserted here.
    assert list(design_frame.columns.get_level_values(1)) == ["a", "b", "c"]
    assert list(design_frame.index) == [0, 1, 2]
    assert design_frame.to_numpy().tolist() == [[1, 0, 3], [2, 2, 1], [3, 0, 3]]