Skip to content

Commit

Permalink
Remove dependency on ropt config in forward model evaluations
Browse files Browse the repository at this point in the history
  • Loading branch information
verveerpj committed Jan 8, 2025
1 parent eb9b04a commit 7127528
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 28 deletions.
42 changes: 14 additions & 28 deletions src/ert/run_models/everest_run_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ def _forward_model_evaluator(
# Gather the results and create the result for ropt:
results = self._gather_simulation_results(ensemble)
evaluator_result = self._make_evaluator_result(
control_values, evaluator_context, batch_data, results, cached_results
control_values, batch_data, results, cached_results
)

# Add the results from the evaluations to the cache:
Expand All @@ -406,18 +406,17 @@ def _get_cached_results(
) -> dict[int, Any]:
cached_results: dict[int, Any] = {}
if self._simulator_cache is not None:
assert evaluator_context.config.realizations.names is not None
for control_idx, real_idx in enumerate(evaluator_context.realizations):
cached_data = self._simulator_cache.get(
int(evaluator_context.config.realizations.names[real_idx]),
self._everest_config.model.realizations[real_idx],
control_values[control_idx, :],
)
if cached_data is not None:
cached_results[control_idx] = cached_data
return cached_results

@staticmethod
def _init_batch_data(
self,
control_values: NDArray[np.float64],
evaluator_context: EvaluatorContext,
cached_results: dict[int, Any],
Expand Down Expand Up @@ -447,9 +446,8 @@ def add_control(
or evaluator_context.active[evaluator_context.realizations[control_idx]]
):
controls: dict[str, Any] = {}
assert evaluator_context.config.variables.names is not None
for control_name, control_value in zip(
evaluator_context.config.variables.names,
self._everest_config.control_name_tuples,
control_values[control_idx, :],
strict=False,
):
Expand Down Expand Up @@ -520,14 +518,11 @@ def _get_run_args(
substitutions = self.ert_config.substitutions
substitutions["<BATCH_NAME>"] = ensemble.name
self.active_realizations = [True] * len(batch_data)
assert evaluator_context.config.realizations.names is not None
for sim_id, control_idx in enumerate(batch_data.keys()):
substitutions[f"<GEO_ID_{sim_id}_0>"] = str(
int(
evaluator_context.config.realizations.names[
evaluator_context.realizations[control_idx]
]
)
self._everest_config.model.realizations[
evaluator_context.realizations[control_idx]
]
)
run_paths = Runpaths(
jobname_format=self.ert_config.model_config.jobname_format_string,
Expand Down Expand Up @@ -587,26 +582,20 @@ def _gather_simulation_results(
def _make_evaluator_result(
self,
control_values: NDArray[np.float64],
evaluator_context: EvaluatorContext,
batch_data: dict[int, Any],
results: list[dict[str, NDArray[np.float64]]],
cached_results: dict[int, Any],
) -> EvaluatorResult:
# We minimize the negative of the objectives:
assert evaluator_context.config.objectives.names is not None
objectives = -self._get_simulation_results(
results,
evaluator_context.config.objectives.names,
control_values,
batch_data,
results, self._everest_config.objective_names, control_values, batch_data
)

constraints = None
if evaluator_context.config.nonlinear_constraints is not None:
assert evaluator_context.config.nonlinear_constraints.names is not None
if self._everest_config.output_constraints:
constraints = self._get_simulation_results(
results,
evaluator_context.config.nonlinear_constraints.names,
self._everest_config.constraint_names,
control_values,
batch_data,
)
Expand All @@ -633,7 +622,7 @@ def _make_evaluator_result(
@staticmethod
def _get_simulation_results(
results: list[dict[str, NDArray[np.float64]]],
names: tuple[str],
names: list[str],
controls: NDArray[np.float64],
batch_data: dict[int, Any],
) -> NDArray[np.float64]:
Expand All @@ -655,14 +644,11 @@ def _add_results_to_cache(
constraints: NDArray[np.float64] | None,
) -> None:
if self._simulator_cache is not None:
assert evaluator_context.config.realizations.names is not None
for control_idx in batch_data:
self._simulator_cache.add(
int(
evaluator_context.config.realizations.names[
evaluator_context.realizations[control_idx]
]
),
self._everest_config.model.realizations[
evaluator_context.realizations[control_idx]
],
control_values[control_idx, ...],
objectives[control_idx, ...],
None if constraints is None else constraints[control_idx, ...],
Expand Down
43 changes: 43 additions & 0 deletions src/everest/config/everest_config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import os
from argparse import ArgumentParser
from functools import cached_property
from io import StringIO
from itertools import chain
from pathlib import Path
Expand Down Expand Up @@ -630,6 +631,48 @@ def control_names(self):
controls = self.controls or []
return [control.name for control in controls]

@cached_property
def control_name_tuples(self) -> list[tuple[str, str, int | tuple[str, str]]]:
tuples = []
for control in self.controls:
for variable in control.variables:
if isinstance(variable, ControlVariableGuessListConfig):
for index in range(1, len(variable.initial_guess) + 1):
tuples.append((control.name, variable.name, index))
elif variable.index is not None:
tuples.append((control.name, variable.name, variable.index))
else:
tuples.append((control.name, variable.name))
return tuples

@property
def objective_names(self) -> list[str]:
return [objective.name for objective in self.objective_functions]

@cached_property
def constraint_names(self) -> list[str]:
names: list[str] = []

def _add_output_constraint(rhs_value: float | None, suffix=None):
if rhs_value is not None:
name = constr.name
names.append(name if suffix is None else f"{name}:{suffix}")

for constr in self.output_constraints or []:
_add_output_constraint(
constr.target,
)
_add_output_constraint(
constr.upper_bound,
None if constr.lower_bound is None else "upper",
)
_add_output_constraint(
constr.lower_bound,
None if constr.upper_bound is None else "lower",
)

return names

@property
def result_names(self):
objectives_names = [
Expand Down

0 comments on commit 7127528

Please sign in to comment.