Skip to content

Commit

Permalink
Polars&parquet for responses and scaling factors
Browse files Browse the repository at this point in the history
  • Loading branch information
Yngve S. Kristiansen authored and yngve-sk committed Oct 4, 2024
1 parent 6d5c0a2 commit 3e1b012
Show file tree
Hide file tree
Showing 53 changed files with 2,185 additions and 1,766 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ dependencies = [
"packaging",
"pandas",
"pluggy>=1.3.0",
"polars",
"psutil",
"pyarrow", # extra dependency for pandas (parquet)
"pydantic > 2",
Expand Down
123 changes: 71 additions & 52 deletions src/ert/analysis/_es_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import iterative_ensemble_smoother as ies
import numpy as np
import pandas as pd
import polars
import psutil
from iterative_ensemble_smoother.experimental import (
AdaptiveESMDA,
Expand Down Expand Up @@ -153,56 +153,75 @@ def _get_observations_and_responses(
observation_values = []
observation_errors = []
indexes = []
observations = ensemble.experiment.observations
for obs in selected_observations:
observation = observations[obs]
group = observation.attrs["response"]
all_responses = ensemble.load_responses(group, tuple(iens_active_index))
if "time" in observation.coords:
all_responses = all_responses.reindex(
time=observation.time,
method="nearest",
observations_by_type = ensemble.experiment.observations
for (
response_type,
response_cls,
) in ensemble.experiment.response_configuration.items():
if response_type not in observations_by_type:
continue

observations_for_type = observations_by_type[response_type].filter(
polars.col("observation_key").is_in(list(selected_observations))
)
responses_for_type = ensemble.load_responses(
response_type, realizations=tuple(iens_active_index)
)

# Note that if there are duplicate entries for one
# response at one index, they are aggregated together
# with "mean" by default
pivoted = responses_for_type.pivot(
on="realization",
index=["response_key", *response_cls.primary_key],
aggregate_function="mean",
)

# Note2reviewer:
# We need to either assume that if there is a time column
# we will approx-join that, or we could specify in response configs
# that there is a column that requires an approx "asof" join.
# Suggest we simplify and assume that there is always only
# one "time" column, which we will reindex towards the response dataset
# with a given resolution
if "time" in pivoted:
joined = observations_for_type.join_asof(
pivoted,
by=["response_key", *response_cls.primary_key],
on="time",
tolerance="1s",
)
try:
observations_and_responses = observation.merge(all_responses, join="left")
except KeyError as e:
raise ErtAnalysisError(
f"Mismatched index for: "
f"Observation: {obs} attached to response: {group}"
) from e

observation_keys.append([obs] * observations_and_responses["observations"].size)

if group == "summary":
indexes.append(
[
np.datetime_as_string(e, unit="s")
for e in observations_and_responses["time"].data
]
)
else:
indexes.append(
[
f"{e[0]}, {e[1]}"
for e in zip(
list(observations_and_responses["report_step"].data)
* len(observations_and_responses["index"].data),
observations_and_responses["index"].data,
)
]
joined = observations_for_type.join(
pivoted,
how="left",
on=["response_key", *response_cls.primary_key],
)

observation_values.append(
observations_and_responses["observations"].data.ravel()
)
observation_errors.append(observations_and_responses["std"].data.ravel())
joined = joined.sort(by="observation_key")

index_1d = joined.with_columns(
polars.concat_str(response_cls.primary_key, separator=", ").alias("index")
)["index"].to_numpy()

obs_keys_1d = joined["observation_key"].to_numpy()
obs_values_1d = joined["observations"].to_numpy()
obs_errors_1d = joined["std"].to_numpy()

# 4 columns are always there:
# [ response_key, observation_key, observations, std ]
# + one column per "primary key" column
num_non_response_value_columns = 4 + len(response_cls.primary_key)
responses = joined.select(
joined.columns[num_non_response_value_columns:]
).to_numpy()

filtered_responses.append(responses)
observation_keys.append(obs_keys_1d)
observation_values.append(obs_values_1d)
observation_errors.append(obs_errors_1d)
indexes.append(index_1d)

filtered_responses.append(
observations_and_responses["values"]
.transpose(..., "realization")
.values.reshape((-1, len(observations_and_responses.realization)))
)
ensemble.load_responses.cache_clear()
return (
np.concatenate(filtered_responses),
Expand Down Expand Up @@ -288,12 +307,14 @@ def _load_observations_and_responses(
scaling[obs_group_mask] *= scaling_factors

scaling_factors_dfs.append(
pd.DataFrame(
data={
polars.DataFrame(
{
"input_group": [", ".join(input_group)] * len(scaling_factors),
"index": indexes[obs_group_mask],
"obs_key": obs_keys[obs_group_mask],
"scaling_factor": scaling_factors,
"scaling_factor": polars.Series(
scaling_factors, dtype=polars.Float32
),
}
)
)
Expand Down Expand Up @@ -322,10 +343,8 @@ def _load_observations_and_responses(
)
)

scaling_factors_df = pd.concat(scaling_factors_dfs).set_index(
["input_group", "obs_key", "index"], verify_integrity=True
)
ensemble.save_observation_scaling_factors(scaling_factors_df.to_xarray())
scaling_factors_df = polars.concat(scaling_factors_dfs)
ensemble.save_observation_scaling_factors(scaling_factors_df)

# Recompute with updated scales
scaled_errors = errors * scaling
Expand Down
26 changes: 24 additions & 2 deletions src/ert/config/ert_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
overload,
)

import xarray as xr
import polars
from pydantic import ValidationError as PydanticValidationError
from typing_extensions import Self

Expand Down Expand Up @@ -112,6 +112,28 @@ class ErtConfig:
Tuple[str, Union[HistoryValues, SummaryValues, GenObsValues]]
] = field(default_factory=list)

def __eq__(self, other: object) -> bool:
if not isinstance(other, ErtConfig):
return False

for attr in vars(self):
if attr == "observations":
if self.observations.keys() != other.observations.keys():
return False

if not all(
self.observations[k].equals(other.observations[k])
for k in self.observations
):
return False

continue

if getattr(self, attr) != getattr(other, attr):
return False

return True

def __post_init__(self) -> None:
self.config_path = (
path.dirname(path.abspath(self.user_config_file))
Expand All @@ -120,7 +142,7 @@ def __post_init__(self) -> None:
)
self.enkf_obs: EnkfObs = self._create_observations(self.observation_config)

self.observations: Dict[str, xr.Dataset] = self.enkf_obs.datasets
self.observations: Dict[str, polars.DataFrame] = self.enkf_obs.datasets

@staticmethod
def with_plugins(
Expand Down
34 changes: 20 additions & 14 deletions src/ert/config/gen_data_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import List, Optional, Tuple

import numpy as np
import xarray as xr
import polars
from typing_extensions import Self

from ert.validation import rangestring_to_list
Expand Down Expand Up @@ -107,21 +107,23 @@ def from_config_dict(cls, config_dict: ConfigDict) -> Optional[Self]:
report_steps_list=report_steps,
)

def read_from_file(self, run_path: str, _: int) -> xr.Dataset:
def _read_file(filename: Path, report_step: int) -> xr.Dataset:
def read_from_file(self, run_path: str, _: int) -> polars.DataFrame:
def _read_file(filename: Path, report_step: int) -> polars.DataFrame:
if not filename.exists():
raise ValueError(f"Missing output file: {filename}")
data = np.loadtxt(_run_path / filename, ndmin=1)
active_information_file = _run_path / (str(filename) + "_active")
if active_information_file.exists():
active_list = np.loadtxt(active_information_file)
data[active_list == 0] = np.nan
return xr.Dataset(
{"values": (["report_step", "index"], [data])},
coords={
"index": np.arange(len(data)),
"report_step": [report_step],
},
return polars.DataFrame(
{
"report_step": polars.Series(
np.full(len(data), report_step), dtype=polars.UInt16
),
"index": polars.Series(np.arange(len(data)), dtype=polars.UInt16),
"values": polars.Series(data, dtype=polars.Float32),
}
)

errors = []
Expand Down Expand Up @@ -150,16 +152,16 @@ def _read_file(filename: Path, report_step: int) -> xr.Dataset:
except ValueError as err:
errors.append(str(err))

ds_all_report_steps = xr.concat(
datasets_per_report_step, dim="report_step"
).expand_dims(name=[name])
ds_all_report_steps = polars.concat(datasets_per_report_step)
ds_all_report_steps.insert_column(
0, polars.Series("response_key", [name] * len(ds_all_report_steps))
)
datasets_per_name.append(ds_all_report_steps)

if errors:
raise ValueError(f"Error reading GEN_DATA: {self.name}, errors: {errors}")

combined = xr.concat(datasets_per_name, dim="name")
combined.attrs["response"] = "gen_data"
combined = polars.concat(datasets_per_name)
return combined

def get_args_for_key(self, key: str) -> Tuple[Optional[str], Optional[List[int]]]:
Expand All @@ -173,5 +175,9 @@ def get_args_for_key(self, key: str) -> Tuple[Optional[str], Optional[List[int]]
def response_type(self) -> str:
return "gen_data"

@property
def primary_key(self) -> List[str]:
return ["report_step", "index"]


responses_index.add_response_type(GenDataConfig)
53 changes: 35 additions & 18 deletions src/ert/config/observation_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, Iterable, List, Union

import xarray as xr
import numpy as np

from .enkf_observation_implementation_type import EnkfObservationImplementationType
from .general_observation import GenObservation
Expand All @@ -12,6 +12,8 @@
if TYPE_CHECKING:
from datetime import datetime

import polars


@dataclass
class ObsVector:
Expand All @@ -27,28 +29,38 @@ def __iter__(self) -> Iterable[Union[SummaryObservation, GenObservation]]:
def __len__(self) -> int:
return len(self.observations)

def to_dataset(self, active_list: List[int]) -> xr.Dataset:
def to_dataset(self, active_list: List[int]) -> polars.DataFrame:
if self.observation_type == EnkfObservationImplementationType.GEN_OBS:
datasets = []
dataframes = []
for time_step, node in self.observations.items():
if active_list and time_step not in active_list:
continue

assert isinstance(node, GenObservation)
datasets.append(
xr.Dataset(
dataframes.append(
polars.DataFrame(
{
"observations": (["report_step", "index"], [node.values]),
"std": (["report_step", "index"], [node.stds]),
},
coords={"index": node.indices, "report_step": [time_step]},
"response_key": self.data_key,
"observation_key": self.observation_key,
"report_step": polars.Series(
np.full(len(node.indices), time_step),
dtype=polars.UInt16,
),
"index": polars.Series(node.indices, dtype=polars.UInt16),
"observations": polars.Series(
node.values, dtype=polars.Float32
),
"std": polars.Series(node.stds, dtype=polars.Float32),
}
)
)
combined = xr.combine_by_coords(datasets)
combined.attrs["response"] = self.data_key
return combined # type: ignore

combined = polars.concat(dataframes)
return combined
elif self.observation_type == EnkfObservationImplementationType.SUMMARY_OBS:
observations = []
actual_response_key = self.observation_key
actual_observation_keys = []
errors = []
dates = list(self.observations.keys())
if active_list:
Expand All @@ -57,15 +69,20 @@ def to_dataset(self, active_list: List[int]) -> xr.Dataset:
for time_step in dates:
n = self.observations[time_step]
assert isinstance(n, SummaryObservation)
actual_observation_keys.append(n.observation_key)
observations.append(n.value)
errors.append(n.std)
return xr.Dataset(

dates_series = polars.Series(dates).dt.cast_time_unit("ms")

return polars.DataFrame(
{
"observations": (["name", "time"], [observations]),
"std": (["name", "time"], [errors]),
},
coords={"time": dates, "name": [self.observation_key]},
attrs={"response": "summary"},
"response_key": actual_response_key,
"observation_key": actual_observation_keys,
"time": dates_series,
"observations": polars.Series(observations, dtype=polars.Float32),
"std": polars.Series(errors, dtype=polars.Float32),
}
)
else:
raise ValueError(f"Unknown observation type {self.observation_type}")
Loading

0 comments on commit 3e1b012

Please sign in to comment.