From d07144e005f5249daade4a5797bdad0672a91f14 Mon Sep 17 00:00:00 2001 From: Peter Verveer Date: Wed, 18 Dec 2024 07:49:43 +0000 Subject: [PATCH] Update exit_code code --- src/ert/run_models/everest_run_model.py | 38 ++++++++++++++-------- src/everest/detached/jobs/everserver.py | 43 ++++++++++++++----------- 2 files changed, 49 insertions(+), 32 deletions(-) diff --git a/src/ert/run_models/everest_run_model.py b/src/ert/run_models/everest_run_model.py index 0437ba389c9..8ff1a87ea4b 100644 --- a/src/ert/run_models/everest_run_model.py +++ b/src/ert/run_models/everest_run_model.py @@ -1,6 +1,7 @@ from __future__ import annotations import datetime +from enum import IntEnum import functools import json import logging @@ -72,6 +73,14 @@ class OptimizerCallback(Protocol): def __call__(self) -> str | None: ... +class EverestExitCode(IntEnum): + COMPLETED = 1 + TOO_FEW_REALIZATIONS = 2 + MAX_FUNCTIONS_REACHED = 3 + MAX_BATCH_NUM_REACHED = 4 + USER_ABORT = 5 + + @dataclass class OptimalResult: batch: int @@ -120,10 +129,7 @@ def __init__( self._fm_errors: dict[int, dict[str, Any]] = {} self._display_all_jobs = display_all_jobs self._result: OptimalResult | None = None - self._exit_code: Literal["max_batch_num_reached"] | OptimizerExitCode | None = ( - None - ) - self._max_batch_num_reached = False + self._exit_code: EverestExitCode | None = None self._evaluator_cache: _EvaluatorCache | None = None if ( everest_config.simulator is not None @@ -180,9 +186,7 @@ def description(cls) -> str: return "Run batches " @property - def exit_code( - self, - ) -> Literal["max_batch_num_reached"] | OptimizerExitCode | None: + def exit_code(self) -> EverestExitCode | None: return self._exit_code @property @@ -226,11 +230,8 @@ def run_experiment( seba_storage.get_optimal_result() # type: ignore ) - self._exit_code = ( - "max_batch_num_reached" - if self._max_batch_num_reached - else optimizer_exit_code - ) + if self._exit_code is None: + self._exit_code = self._get_exit_code(optimizer_exit_code) def _create_optimizer(self) -> BasicOptimizer: RESULT_COLUMNS = { @@ -332,6 +333,17 @@ def _create_optimizer(self) -> BasicOptimizer: return optimizer + def _get_exit_code(self, optimizer_exit_code: OptimizerExitCode) -> EverestExitCode: + match optimizer_exit_code: + case OptimizerExitCode.MAX_FUNCTIONS_REACHED: + return EverestExitCode.MAX_FUNCTIONS_REACHED + case OptimizerExitCode.USER_ABORT: + return EverestExitCode.USER_ABORT + case OptimizerExitCode.TOO_FEW_REALIZATIONS: + return EverestExitCode.TOO_FEW_REALIZATIONS + case _: + return EverestExitCode.COMPLETED + def _on_before_forward_model_evaluation( self, _: OptimizerEvent, optimizer: BasicOptimizer ) -> None: @@ -342,7 +354,7 @@ def _on_before_forward_model_evaluation( and self.everest_config.optimization.max_batch_num is not None and (self.batch_id >= self.everest_config.optimization.max_batch_num) ): - self._max_batch_num_reached = True + self._exit_code = EverestExitCode.MAX_BATCH_NUM_REACHED logging.getLogger(EVEREST).info("Maximum number of batches reached") optimizer.abort_optimization() if ( diff --git a/src/everest/detached/jobs/everserver.py b/src/everest/detached/jobs/everserver.py index 936ccedf236..02ae28a1172 100755 --- a/src/everest/detached/jobs/everserver.py +++ b/src/everest/detached/jobs/everserver.py @@ -30,11 +30,10 @@ HTTPBasic, HTTPBasicCredentials, ) -from ropt.enums import OptimizerExitCode from ert.config import QueueSystem from ert.ensemble_evaluator import EvaluatorServerConfig -from ert.run_models.everest_run_model import EverestRunModel +from ert.run_models.everest_run_model import EverestExitCode, EverestRunModel from everest import export_to_csv, export_with_progress from everest.config import EverestConfig, ServerConfig from everest.detached import ServerStatus, get_opt_status, update_everserver_status @@ -373,25 +372,31 @@ def main(): def _get_optimization_status(exit_code, shared_data): - if exit_code == "max_batch_num_reached": - return ServerStatus.completed, "Maximum number of batches reached." - - if exit_code == OptimizerExitCode.MAX_FUNCTIONS_REACHED: - return ServerStatus.completed, "Maximum number of function evaluations reached." + match exit_code: + case EverestExitCode.MAX_BATCH_NUM_REACHED: + return ServerStatus.completed, "Maximum number of batches reached." + + case EverestExitCode.MAX_FUNCTIONS_REACHED: + return ( + ServerStatus.completed, + "Maximum number of function evaluations reached.", + ) - if exit_code == OptimizerExitCode.USER_ABORT: - return ServerStatus.stopped, "Optimization aborted." + case EverestExitCode.USER_ABORT: + return ServerStatus.stopped, "Optimization aborted." - if exit_code == OptimizerExitCode.TOO_FEW_REALIZATIONS: - status = ( - ServerStatus.stopped if shared_data[STOP_ENDPOINT] else ServerStatus.failed - ) - messages = _failed_realizations_messages(shared_data) - for msg in messages: - logging.getLogger(EVEREST).error(msg) - return status, "\n".join(messages) - - return ServerStatus.completed, "Optimization completed." + case EverestExitCode.TOO_FEW_REALIZATIONS: + status = ( + ServerStatus.stopped + if shared_data[STOP_ENDPOINT] + else ServerStatus.failed + ) + messages = _failed_realizations_messages(shared_data) + for msg in messages: + logging.getLogger(EVEREST).error(msg) + return status, "\n".join(messages) + case _: + return ServerStatus.completed, "Optimization completed." def _failed_realizations_messages(shared_data):