Commit 1d2b4f8

verveerpj committed Dec 18, 2024
1 parent 3c9d871 commit 1d2b4f8
Showing 4 changed files with 46 additions and 56 deletions.
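In short, this commit makes several public attributes of EverestRunModel private (ropt_config, everest_config, eval_server_cfg, batch_id, status), inlines two single-use constructor locals, and tidies the callback defaults in create(). A minimal sketch of the renaming pattern (the classes below are illustrative, not the repository code):

    class RunModelBefore:
        def __init__(self) -> None:
            self.batch_id = 0  # public: any caller may read or mutate it

    class RunModelAfter:
        def __init__(self) -> None:
            # The leading underscore marks the attribute as internal by
            # convention; outside code (e.g. tests) must opt in explicitly.
            self._batch_id = 0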
90 changes: 43 additions & 47 deletions src/ert/run_models/everest_run_model.py
@@ -37,7 +37,7 @@
from everest.strings import EVEREST

from ..run_arg import RunArg, create_run_arguments
-from .base_run_model import BaseRunModel, StatusEvents
+from .base_run_model import BaseRunModel

if TYPE_CHECKING:
from ert.storage import Ensemble, Experiment
@@ -116,9 +116,8 @@ def __init__(
"add the above random seed to your configuration file."
)

-self.ropt_config = everest2ropt(everest_config)
-self.everest_config = everest_config
-self.support_restart = False
+self._ropt_config = everest2ropt(everest_config)
+self._everest_config = everest_config

self._sim_callback = simulation_callback
self._opt_callback = optimization_callback
@@ -134,19 +133,18 @@ def __init__(
else None
)
self._experiment: Experiment | None = None
-self.eval_server_cfg: EvaluatorServerConfig | None = None
-storage = open_storage(config.ens_path, mode="w")
-status_queue: queue.SimpleQueue[StatusEvents] = queue.SimpleQueue()
-self.batch_id: int = 0
-self.status: SimulationStatus | None = None
+self._eval_server_cfg: EvaluatorServerConfig | None = None
+self._batch_id: int = 0
+self._status: SimulationStatus | None = None

super().__init__(
config,
-storage,
+open_storage(config.ens_path, mode="w"),
config.queue_config,
-status_queue,
+queue.SimpleQueue(),
active_realizations=[], # Set dynamically in run_forward_model()
)
+self.support_restart = False

@classmethod
def create(
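The constructor hunk above also folds two single-use locals straight into the super().__init__ call; inlining queue.SimpleQueue() drops the explicit [StatusEvents] type parameter, which is why the StatusEvents import disappears in the first hunk. A runnable sketch of the pattern, with open_storage stubbed out (the real helper lives in ert.storage):

    import queue
    from types import SimpleNamespace

    def open_storage(path: str, mode: str = "r") -> SimpleNamespace:
        # Stand-in for ert's open_storage, just enough to run the sketch.
        return SimpleNamespace(path=path, mode=mode)

    class Base:
        def __init__(self, config, storage, queue_config, status_queue) -> None:
            self._storage = storage
            self._status_queue = status_queue

    class After(Base):
        def __init__(self, config) -> None:
            # Values are built in place, so __init__ keeps no single-use locals:
            super().__init__(
                config,
                open_storage(config.ens_path, mode="w"),
                config.queue_config,
                queue.SimpleQueue(),
            )

    After(SimpleNamespace(ens_path="/tmp/ens", queue_config=None))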
@@ -155,9 +153,7 @@ def create(
simulation_callback: SimulationCallback | None = None,
optimization_callback: OptimizerCallback | None = None,
) -> EverestRunModel:
-def default_simulation_callback(
-    simulation_status: SimulationStatus | None,
-) -> str | None:
+def default_simulation_callback(_: SimulationStatus | None) -> str | None:
return None

def default_optimization_callback() -> str | None:
@@ -168,8 +164,9 @@ def default_optimization_callback() -> str | None:
config=ert_config,
everest_config=ever_config,
simulation_callback=simulation_callback or default_simulation_callback,

Check failure on line 166 in src/ert/run_models/everest_run_model.py (GitHub Actions / type-checking (3.12)):
Argument "simulation_callback" to "EverestRunModel" has incompatible type "SimulationCallback | Callable[[SimulationStatus | None], str | None]"; expected "SimulationCallback"
-optimization_callback=optimization_callback
-or default_optimization_callback,
+optimization_callback=(
+    optimization_callback or default_optimization_callback
+),
)

@classmethod
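About the type-checking failure flagged above: mypy types `a or b` as the union of both operands, so simulation_callback or default_simulation_callback is inferred as SimulationCallback | Callable[[SimulationStatus | None], str | None], and that union is rejected where a plain SimulationCallback is declared (the exact alias definitions live in the repository; anything narrower than the inferred union, e.g. a Protocol, triggers the error). A self-contained sketch of the pattern and one conventional fix via typing.cast; all names here are stand-ins:

    from collections.abc import Callable
    from typing import cast

    SimulationStatus = dict  # stand-in for the real status type
    SimulationCallback = Callable[[SimulationStatus | None], str | None]

    def default_simulation_callback(_: SimulationStatus | None) -> str | None:
        return None

    def create(simulation_callback: SimulationCallback | None = None) -> None:
        # The cast collapses the inferred union back to the declared alias:
        cb = cast(
            SimulationCallback, simulation_callback or default_simulation_callback
        )
        cb(None)

    create()  # falls back to the default callback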
@@ -189,15 +186,15 @@ def result(self) -> OptimalResult | None:
return self._result

def __repr__(self) -> str:
-config_json = json.dumps(self.everest_config, sort_keys=True, indent=2)
+config_json = json.dumps(self._everest_config, sort_keys=True, indent=2)
return f"EverestRunModel(config={config_json})"

def run_experiment(
self, evaluator_server_config: EvaluatorServerConfig, restart: bool = False
) -> None:
self.log_at_startup()
self.restart = restart
-self.eval_server_cfg = evaluator_server_config
+self._eval_server_cfg = evaluator_server_config
self._experiment = self._storage.create_experiment(
name=f"EnOpt@{datetime.datetime.now().strftime('%Y-%m-%d@%H:%M:%S')}",
parameters=self.ert_config.ensemble_config.parameter_configuration,
@@ -214,7 +211,7 @@ def run_experiment(
# This mechanism is outdated and not supported by the ropt package. It
# is retained for now via the seba_sqlite package.
seba_storage = SqliteStorage( # type: ignore
-optimizer, self.everest_config.optimization_output_dir
+optimizer, self._everest_config.optimization_output_dir
)

# Run the optimization:
@@ -284,10 +281,10 @@ def _create_optimizer(self) -> BasicOptimizer:
# simplifying code that reads them as fixed width tables. `maximize` is
# set because ropt reports minimization results, while everest wants
# maximization results, necessitating a conversion step.
-ropt_output_folder = Path(self.everest_config.optimization_output_dir)
+ropt_output_folder = Path(self._everest_config.optimization_output_dir)
optimizer = (
BasicOptimizer(
-enopt_config=self.ropt_config, evaluator=self._forward_model_evaluator
+enopt_config=self._ropt_config, evaluator=self._forward_model_evaluator
)
.add_table(
columns=RESULT_COLUMNS,
@@ -347,9 +344,9 @@ def _on_before_forward_model_evaluation(
logging.getLogger(EVEREST).debug("Optimization callback called")

if (
-self.everest_config.optimization is not None
-and self.everest_config.optimization.max_batch_num is not None
-and (self.batch_id >= self.everest_config.optimization.max_batch_num)
+self._everest_config.optimization is not None
+and self._everest_config.optimization.max_batch_num is not None
+and (self._batch_id >= self._everest_config.optimization.max_batch_num)
):
self._exit_code = EverestExitCode.MAX_BATCH_NUM_REACHED
logging.getLogger(EVEREST).info("Maximum number of batches reached")
@@ -365,7 +362,7 @@ def _forward_model_evaluator(
self, control_values: NDArray[np.float64], evaluator_context: EvaluatorContext
) -> EvaluatorResult:
# Reset the current run status:
-self.status = None
+self._status = None

# Get any cached_results results that may be useful:
cached_results = self._get_cached_results(control_values, evaluator_context)
@@ -378,7 +375,7 @@
# Initialize a new experiment in storage:
assert self._experiment
ensemble = self._experiment.create_ensemble(
-name=f"batch_{self.batch_id}",
+name=f"batch_{self._batch_id}",
ensemble_size=len(case_data),
)
for sim_id, controls in enumerate(case_data.values()):
@@ -393,8 +390,8 @@
"_ERT_SIMULATION_MODE": "batch_simulation",
}
)
-assert self.eval_server_cfg
-self._evaluate_and_postprocess(run_args, ensemble, self.eval_server_cfg)
+assert self._eval_server_cfg
+self._evaluate_and_postprocess(run_args, ensemble, self._eval_server_cfg)

# If necessary, delete the run path:
self._delete_runpath(run_args)
@@ -406,7 +403,7 @@
)

# Increase the batch ID for the next evaluation:
-self.batch_id += 1
+self._batch_id += 1

# Add the results from the evaluations to the cache:
self._add_results_to_cache(
@@ -526,7 +523,7 @@ def _check_suffix(
f"Key {key} has suffixes, a suffix must be specified"
)

-if set(controls.keys()) != set(self.everest_config.control_names):
+if set(controls.keys()) != set(self._everest_config.control_names):
err_msg = "Mismatch between initialized and provided control names."
raise KeyError(err_msg)

@@ -562,11 +559,10 @@ def _slug(entity: str) -> str:
self.active_realizations = [True] * len(case_data)
assert evaluator_context.config.realizations.names is not None
for sim_id, control_idx in enumerate(case_data.keys()):
-if self.active_realizations[sim_id]:
-    realization = evaluator_context.realizations[control_idx]
-    substitutions[f"<GEO_ID_{sim_id}_0>"] = str(
-        evaluator_context.config.realizations.names[realization]
-    )
+realization = evaluator_context.realizations[control_idx]
+substitutions[f"<GEO_ID_{sim_id}_0>"] = str(
+    evaluator_context.config.realizations.names[realization]
+)

run_paths = Runpaths(
jobname_format=self.ert_config.model_config.jobname_format_string,
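The hunk above is a behavior-preserving cleanup rather than a rename: self.active_realizations is assigned [True] * len(case_data) just before the loop, so the removed if self.active_realizations[sim_id]: guard could never be false, and dropping it flattens the loop body by one level. A reduced, runnable illustration:

    case_data = {0: "a", 1: "b"}
    active_realizations = [True] * len(case_data)  # every entry True by construction

    for sim_id, control_idx in enumerate(case_data.keys()):
        # The removed guard was equivalent to `if True:` here.
        assert active_realizations[sim_id]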
@@ -584,8 +580,8 @@
def _delete_runpath(self, run_args: list[RunArg]) -> None:
logging.getLogger(EVEREST).debug("Simulation callback called")
if (
-self.everest_config.simulator is not None
-and self.everest_config.simulator.delete_run_path
+self._everest_config.simulator is not None
+and self._everest_config.simulator.delete_run_path
):
for i, real in self.get_current_snapshot().reals.items():
path_to_delete = run_args[int(i)].runpath
@@ -614,11 +610,11 @@ def _gather_results(
results.append({})
continue
d = {}
-for key in self.everest_config.result_names:
+for key in self._everest_config.result_names:
data = ensemble.load_responses(key, (sim_id,))
d[key] = data["values"].to_numpy()
results.append(d)
-for fnc_name, alias in self.everest_config.function_aliases.items():
+for fnc_name, alias in self._everest_config.function_aliases.items():
for result in results:
result[fnc_name] = result[alias]
return results
@@ -663,7 +659,7 @@ def _get_evaluator_result(
return EvaluatorResult(
objectives=objectives,
constraints=constraints,
-batch_id=self.batch_id,
+batch_id=self._batch_id,
evaluation_ids=sim_ids,
)

@@ -688,18 +684,18 @@ def _add_results_to_cache(

def check_if_runpath_exists(self) -> bool:
return (
-self.everest_config.simulation_dir is not None
-and os.path.exists(self.everest_config.simulation_dir)
-and any(os.listdir(self.everest_config.simulation_dir))
+self._everest_config.simulation_dir is not None
+and os.path.exists(self._everest_config.simulation_dir)
+and any(os.listdir(self._everest_config.simulation_dir))
)

def send_snapshot_event(self, event: Event, iteration: int) -> None:
super().send_snapshot_event(event, iteration)
if type(event) in (EESnapshot, EESnapshotUpdate):
newstatus = self._simulation_status(self.get_current_snapshot())
-if self.status != newstatus: # No change in status
+if self._status != newstatus: # No change in status
self._sim_callback(newstatus)
-self.status = newstatus
+self._status = newstatus

def _simulation_status(self, snapshot: EnsembleSnapshot) -> SimulationStatus:
jobs_progress: list[list[JobProgress]] = []
@@ -724,7 +720,7 @@ def _simulation_status(self, snapshot: EnsembleSnapshot) -> SimulationStatus:
)
if fm_step.get("error", ""):
self._handle_errors(
-batch=self.batch_id,
+batch=self._batch_id,
simulation=simulation,
realization=realization,
fm_name=fm_step.get("name", "Unknown"), # type: ignore
@@ -735,7 +731,7 @@ def _simulation_status(self, snapshot: EnsembleSnapshot) -> SimulationStatus:
return {
"status": self.get_current_status(),
"progress": jobs_progress,
-"batch_number": self.batch_id,
+"batch_number": self._batch_id,
}

def _handle_errors(
8 changes: 1 addition & 7 deletions tests/everest/test_everest_output.py
@@ -52,13 +52,7 @@ async def test_everest_output(copy_mocked_test_data_to_tmp):
initial_folders = set(folders)
initial_files = set(files)

-# Tests in this class used to fail when a callback was passed in
-# Use a callback just to see that everything works fine, even though
-# the callback does nothing
-def useless_cb(*args, **kwargs):
-    pass
-
-EverestRunModel.create(config, optimization_callback=useless_cb)
+EverestRunModel.create(config)

# Check the output folder is created when stating the optimization
# in everest workflow
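Since EverestRunModel.create now always substitutes do-nothing defaults for missing callbacks (see the create() hunk in the first file), the test no longer needs its own placeholder callback. Roughly, and using only signatures visible in this diff, the following two calls behave the same after this commit:

    run_model = EverestRunModel.create(config)
    run_model = EverestRunModel.create(
        config,
        simulation_callback=lambda _status: None,  # what default_simulation_callback does
        optimization_callback=lambda: None,        # what default_optimization_callback does
    )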
2 changes: 1 addition & 1 deletion tests/everest/test_everserver.py
@@ -121,7 +121,7 @@ def test_everserver_status_failure(_1, copy_math_func_test_data_to_tmp):
"ert.run_models.everest_run_model.EverestRunModel.run_experiment",
autospec=True,
side_effect=lambda self, evaluator_server_config, restart=False: check_status(
-ServerConfig.get_hostfile_path(self.everest_config.output_dir),
+ServerConfig.get_hostfile_path(self._everest_config.output_dir),
status=ServerStatus.running,
),
)
2 changes: 1 addition & 1 deletion tests/everest/test_simulator_cache.py
@@ -47,7 +47,7 @@ def new_call(*args):
Path("everest_output/optimization_output/seba.db").unlink()

# The batch_id was used as a stopping criterion, so it must be reset:
-run_model.batch_id = 0
+run_model._batch_id = 0

run_model.run_experiment(evaluator_server_config)
assert n_evals == 0
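This reset matters because _batch_id doubles as the stopping criterion: _on_before_forward_model_evaluation (earlier in this diff) compares it against optimization.max_batch_num and exits with MAX_BATCH_NUM_REACHED once the limit is hit. A toy, runnable sketch of that interaction:

    class ModelSketch:
        def __init__(self, max_batch_num: int) -> None:
            self._batch_id = 0
            self.max_batch_num = max_batch_num

        def run_experiment(self) -> str:
            while True:
                if self._batch_id >= self.max_batch_num:
                    return "MAX_BATCH_NUM_REACHED"
                self._batch_id += 1  # one increment per evaluated batch

    model = ModelSketch(max_batch_num=2)
    assert model.run_experiment() == "MAX_BATCH_NUM_REACHED"
    model._batch_id = 0  # without the reset, a rerun would stop immediately
    assert model.run_experiment() == "MAX_BATCH_NUM_REACHED"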
