-
Notifications
You must be signed in to change notification settings - Fork 110
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Save & load responses as parquet #8684
Merged
+2,221
−1,766
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,7 +18,7 @@ | |
|
||
import iterative_ensemble_smoother as ies | ||
import numpy as np | ||
import pandas as pd | ||
import polars | ||
import psutil | ||
from iterative_ensemble_smoother.experimental import ( | ||
AdaptiveESMDA, | ||
|
@@ -153,56 +153,75 @@ def _get_observations_and_responses( | |
observation_values = [] | ||
observation_errors = [] | ||
indexes = [] | ||
observations = ensemble.experiment.observations | ||
for obs in selected_observations: | ||
observation = observations[obs] | ||
group = observation.attrs["response"] | ||
all_responses = ensemble.load_responses(group, tuple(iens_active_index)) | ||
if "time" in observation.coords: | ||
all_responses = all_responses.reindex( | ||
time=observation.time, | ||
method="nearest", | ||
observations_by_type = ensemble.experiment.observations | ||
for ( | ||
response_type, | ||
response_cls, | ||
) in ensemble.experiment.response_configuration.items(): | ||
if response_type not in observations_by_type: | ||
continue | ||
|
||
observations_for_type = observations_by_type[response_type].filter( | ||
polars.col("observation_key").is_in(list(selected_observations)) | ||
) | ||
responses_for_type = ensemble.load_responses( | ||
response_type, realizations=tuple(iens_active_index) | ||
) | ||
|
||
# Note that if there are duplicate entries for one | ||
# response at one index, they are aggregated together | ||
# with "mean" by default | ||
pivoted = responses_for_type.pivot( | ||
on="realization", | ||
index=["response_key", *response_cls.primary_key], | ||
aggregate_function="mean", | ||
) | ||
|
||
# Note2reviewer: | ||
# We need to either assume that if there is a time column | ||
# we will approx-join that, or we could specify in response configs | ||
# that there is a column that requires an approx "asof" join. | ||
# Suggest we simplify and assume that there is always only | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Agree, if and when we add new response types where this might be relevant we can add it then. |
||
# one "time" column, which we will reindex towards the response dataset | ||
# with a given resolution | ||
if "time" in pivoted: | ||
joined = observations_for_type.join_asof( | ||
pivoted, | ||
by=["response_key", *response_cls.primary_key], | ||
on="time", | ||
tolerance="1s", | ||
) | ||
try: | ||
observations_and_responses = observation.merge(all_responses, join="left") | ||
except KeyError as e: | ||
raise ErtAnalysisError( | ||
f"Mismatched index for: " | ||
f"Observation: {obs} attached to response: {group}" | ||
) from e | ||
|
||
observation_keys.append([obs] * observations_and_responses["observations"].size) | ||
|
||
if group == "summary": | ||
indexes.append( | ||
[ | ||
np.datetime_as_string(e, unit="s") | ||
for e in observations_and_responses["time"].data | ||
] | ||
) | ||
else: | ||
indexes.append( | ||
[ | ||
f"{e[0]}, {e[1]}" | ||
for e in zip( | ||
list(observations_and_responses["report_step"].data) | ||
* len(observations_and_responses["index"].data), | ||
observations_and_responses["index"].data, | ||
) | ||
] | ||
joined = observations_for_type.join( | ||
pivoted, | ||
how="left", | ||
on=["response_key", *response_cls.primary_key], | ||
) | ||
|
||
observation_values.append( | ||
observations_and_responses["observations"].data.ravel() | ||
) | ||
observation_errors.append(observations_and_responses["std"].data.ravel()) | ||
joined = joined.sort(by="observation_key") | ||
|
||
index_1d = joined.with_columns( | ||
polars.concat_str(response_cls.primary_key, separator=", ").alias("index") | ||
)["index"].to_numpy() | ||
|
||
obs_keys_1d = joined["observation_key"].to_numpy() | ||
obs_values_1d = joined["observations"].to_numpy() | ||
obs_errors_1d = joined["std"].to_numpy() | ||
|
||
# 4 columns are always there: | ||
# [ response_key, observation_key, observations, std ] | ||
# + one column per "primary key" column | ||
num_non_response_value_columns = 4 + len(response_cls.primary_key) | ||
responses = joined.select( | ||
joined.columns[num_non_response_value_columns:] | ||
).to_numpy() | ||
|
||
filtered_responses.append(responses) | ||
observation_keys.append(obs_keys_1d) | ||
observation_values.append(obs_values_1d) | ||
observation_errors.append(obs_errors_1d) | ||
indexes.append(index_1d) | ||
|
||
filtered_responses.append( | ||
observations_and_responses["values"] | ||
.transpose(..., "realization") | ||
.values.reshape((-1, len(observations_and_responses.realization))) | ||
) | ||
ensemble.load_responses.cache_clear() | ||
return ( | ||
np.concatenate(filtered_responses), | ||
|
@@ -288,12 +307,14 @@ def _load_observations_and_responses( | |
scaling[obs_group_mask] *= scaling_factors | ||
|
||
scaling_factors_dfs.append( | ||
pd.DataFrame( | ||
data={ | ||
polars.DataFrame( | ||
{ | ||
"input_group": [", ".join(input_group)] * len(scaling_factors), | ||
"index": indexes[obs_group_mask], | ||
"obs_key": obs_keys[obs_group_mask], | ||
"scaling_factor": scaling_factors, | ||
"scaling_factor": polars.Series( | ||
scaling_factors, dtype=polars.Float32 | ||
), | ||
} | ||
) | ||
) | ||
|
@@ -322,10 +343,8 @@ def _load_observations_and_responses( | |
) | ||
) | ||
|
||
scaling_factors_df = pd.concat(scaling_factors_dfs).set_index( | ||
["input_group", "obs_key", "index"], verify_integrity=True | ||
) | ||
ensemble.save_observation_scaling_factors(scaling_factors_df.to_xarray()) | ||
scaling_factors_df = polars.concat(scaling_factors_dfs) | ||
ensemble.save_observation_scaling_factors(scaling_factors_df) | ||
|
||
# Recompute with updated scales | ||
scaled_errors = errors * scaling | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,7 +5,7 @@ | |
from typing import List, Optional, Tuple | ||
|
||
import numpy as np | ||
import xarray as xr | ||
import polars | ||
from typing_extensions import Self | ||
|
||
from ert.validation import rangestring_to_list | ||
|
@@ -107,21 +107,23 @@ def from_config_dict(cls, config_dict: ConfigDict) -> Optional[Self]: | |
report_steps_list=report_steps, | ||
) | ||
|
||
def read_from_file(self, run_path: str, _: int) -> xr.Dataset: | ||
def _read_file(filename: Path, report_step: int) -> xr.Dataset: | ||
def read_from_file(self, run_path: str, _: int) -> polars.DataFrame: | ||
def _read_file(filename: Path, report_step: int) -> polars.DataFrame: | ||
if not filename.exists(): | ||
raise ValueError(f"Missing output file: {filename}") | ||
data = np.loadtxt(_run_path / filename, ndmin=1) | ||
active_information_file = _run_path / (str(filename) + "_active") | ||
if active_information_file.exists(): | ||
active_list = np.loadtxt(active_information_file) | ||
data[active_list == 0] = np.nan | ||
return xr.Dataset( | ||
{"values": (["report_step", "index"], [data])}, | ||
coords={ | ||
"index": np.arange(len(data)), | ||
"report_step": [report_step], | ||
}, | ||
return polars.DataFrame( | ||
{ | ||
"report_step": polars.Series( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This made it much easier to read! |
||
np.full(len(data), report_step), dtype=polars.UInt16 | ||
), | ||
"index": polars.Series(np.arange(len(data)), dtype=polars.UInt16), | ||
"values": polars.Series(data, dtype=polars.Float32), | ||
} | ||
) | ||
|
||
errors = [] | ||
|
@@ -150,16 +152,16 @@ def _read_file(filename: Path, report_step: int) -> xr.Dataset: | |
except ValueError as err: | ||
errors.append(str(err)) | ||
|
||
ds_all_report_steps = xr.concat( | ||
datasets_per_report_step, dim="report_step" | ||
).expand_dims(name=[name]) | ||
ds_all_report_steps = polars.concat(datasets_per_report_step) | ||
ds_all_report_steps.insert_column( | ||
0, polars.Series("response_key", [name] * len(ds_all_report_steps)) | ||
) | ||
datasets_per_name.append(ds_all_report_steps) | ||
|
||
if errors: | ||
raise ValueError(f"Error reading GEN_DATA: {self.name}, errors: {errors}") | ||
|
||
combined = xr.concat(datasets_per_name, dim="name") | ||
combined.attrs["response"] = "gen_data" | ||
combined = polars.concat(datasets_per_name) | ||
return combined | ||
|
||
def get_args_for_key(self, key: str) -> Tuple[Optional[str], Optional[List[int]]]: | ||
|
@@ -173,5 +175,9 @@ def get_args_for_key(self, key: str) -> Tuple[Optional[str], Optional[List[int]] | |
def response_type(self) -> str: | ||
return "gen_data" | ||
|
||
@property | ||
def primary_key(self) -> List[str]: | ||
return ["report_step", "index"] | ||
|
||
|
||
responses_index.add_response_type(GenDataConfig) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is the implication of
mean
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It said so in the comment 😅
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will that be output somewhere? Is it possible to for example log it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is for the edge case where we end up with duplicate values for one response at one index, for example a given time. In that case, we need to aggregate them for the pivoted table to make sense, else the
index
used to pivot contains duplicates. So taking the average of the duplicate response values on the timestep seems to be somewhat "close enough" to do what we want, we could set it to use min,max,median,first, etc, could configure it, but not sure if it would be interesting to users to do this?Example from running
test_that_duplicate_summary_time_steps_does_not_fail
:Alternatively we could strive to achieve something like this:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could be logged / given as a warning somehow, I'm not so familiar with when/why it happens, which may be relevant to what the warning/logging message should be.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(Performance-wise it might be slow to always check if some values were aggregated, or a naive try-catch around the pivot, as it will pass if there are no duplicate values)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If there is a good, somewhat performant way of warning the user this has happened, that would be good. My hunch is that this would typically happen in pressure tests where the time resolution is quite high, and the simulator does not have the same resolution.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would it be OK to do this in a separate PR? I think the try-catch, first trying without an aggregation, then trying with one, should be easy to add / easy to remove if it turns out to have bad side effects. Should maybe be tested as its own thing just to be sure.