Skip to content

Commit

Permalink
Refactor rft export logic for polars
Browse files Browse the repository at this point in the history
  • Loading branch information
Yngve S. Kristiansen authored and yngve-sk committed Sep 30, 2024
1 parent 39a0df8 commit d7ddd62
Showing 1 changed file with 65 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import numpy
import pandas as pd
import polars
from qtpy.QtWidgets import QCheckBox

from ert.config import CancelPluginException, ErtPlugin
Expand Down Expand Up @@ -85,8 +86,6 @@ def run(
ensemble_data_as_json = None if len(workflow_args) < 3 else workflow_args[2]
drop_const_cols = False if len(workflow_args) < 4 else bool(workflow_args[3])

wells = set()

ensemble_data_as_dict = (
json.loads(ensemble_data_as_json) if ensemble_data_as_json else {}
)
Expand All @@ -110,82 +109,96 @@ def run(
f"The ensemble '{ensemble_name}' does not have any data!"
)

obs = ensemble.experiment.observations
obs_df = ensemble.experiment.observations.get("gen_data")
obs_keys = []
for key, _ in obs.items():
for key in ensemble.experiment.observation_keys:
if key.startswith("RFT_"):
obs_keys.append(key)

if len(obs_keys) == 0:
if len(obs_keys) == 0 or obs_df is None:
raise UserWarning(
"The config does not contain any"
" GENERAL_OBSERVATIONS starting with RFT_*"
)

ensemble_data = []
for obs_key in obs_keys:
well = obs_key.replace("RFT_", "")
wells.add(well)
obs_vector = obs[obs_key]
data_key = obs_vector.attrs["response"]
if len(obs_vector.report_step) == 1:
report_step = obs_vector.report_step.values
obs_node = obs_vector.sel(report_step=report_step)
else:
well_key = obs_key.replace("RFT_", "")

obs_df = obs_df.filter(polars.col("observation_key").eq(obs_key))
response_key = obs_df["response_key"].unique().to_list()[0]

if len(obs_df["report_step"].unique()) != 1:
raise UserWarning(
"GEN_DATA RFT CSV Export can only be used for observations "
"active for exactly one report step"
)

realizations = ensemble.get_realization_list_with_responses(data_key)
vals = ensemble.load_responses(data_key, tuple(realizations)).sel(
report_step=report_step, drop=True
)
index = pd.Index(vals.index.values, name="axis")
rft_data = pd.DataFrame(
data=vals["values"].values.reshape(len(vals.realization), -1).T,
index=index,
columns=realizations,
realizations = ensemble.get_realization_list_with_responses(
response_key
)
responses = ensemble.load_responses(response_key, tuple(realizations))
joined = obs_df.join(
responses,
on=["response_key", "report_step", "index"],
how="left",
).drop("index", "report_step")

# Trajectory
trajectory_file = os.path.join(trajectory_path, f"{well}.txt")
trajectory_file = os.path.join(trajectory_path, f"{well_key}.txt")
if not os.path.isfile(trajectory_file):
trajectory_file = os.path.join(trajectory_path, f"{well}_R.txt")
trajectory_file = os.path.join(trajectory_path, f"{well_key}_R.txt")

arg = load_args(
trajectory_file, column_names=["utm_x", "utm_y", "md", "tvd"]
)
tvd_arg = arg["tvd"]

# Observations
for iens in realizations:
realization_frame = pd.DataFrame(
data={
"TVD": tvd_arg,
"Pressure": rft_data[iens],
"ObsValue": obs_node["observations"].values[0],
"ObsStd": obs_node["std"].values[0],
},
columns=["TVD", "Pressure", "ObsValue", "ObsStd"],
)

realization_frame["Realization"] = iens
realization_frame["Well"] = well
realization_frame["Ensemble"] = ensemble_name
realization_frame["Iteration"] = ensemble.iteration

ensemble_data.append(realization_frame)

data.append(pd.concat(ensemble_data))

frame = pd.concat(data)
frame.set_index(["Realization", "Well", "Ensemble", "Iteration"], inplace=True)
if drop_const_cols:
frame = frame.loc[:, (frame != frame.iloc[0]).any()]
all_realization_frames = joined.rename(
{
"realization": "Realization",
"values": "Pressure",
"observations": "ObsValue",
"std": "ObsStd",
}
).with_columns(
[
polars.lit(well_key).alias("Well").cast(polars.String),
polars.lit(ensemble.name).alias("Ensemble").cast(polars.String),
polars.lit(ensemble.iteration)
.alias("Iteration")
.cast(polars.UInt8),
polars.lit(tvd_arg).alias("TVD").cast(polars.Float32),
]
)

frame.to_csv(output_file)
well_list_str = ", ".join(list(wells))
data.append(all_realization_frames)

frame = polars.concat(data)

cols_index = ["Well", "Ensemble", "Iteration"]
const_cols_right = ["ObsValue", "ObsStd"]
const_cols_left = [
col
for col in frame.columns
if (
col not in cols_index
and col not in const_cols_right
and frame[col].n_unique() == 1
)
]

columns_to_export = [
"Realization",
*cols_index,
*(const_cols_left if not drop_const_cols else []),
*["Pressure"],
*(const_cols_right if not drop_const_cols else []),
]

to_export = frame.select(columns_to_export)

to_export.write_csv(output_file, include_header=True)
well_list_str = ", ".join(to_export["Well"].unique().to_list())
export_info = (
f"Exported RFT information for wells: {well_list_str} to: {output_file}"
)
Expand Down

0 comments on commit d7ddd62

Please sign in to comment.