From 891f8ca30a4c252decc5bde145cabd43d0adc703 Mon Sep 17 00:00:00 2001 From: Peter Verveer Date: Wed, 18 Dec 2024 15:18:24 +0000 Subject: [PATCH] EverestRunModel: Introduce EverestExitCode --- src/ert/run_models/everest_run_model.py | 47 +++++++++++++++---------- src/everest/detached/jobs/everserver.py | 43 ++++++++++++---------- tests/everest/test_everserver.py | 6 ++-- 3 files changed, 55 insertions(+), 41 deletions(-) diff --git a/src/ert/run_models/everest_run_model.py b/src/ert/run_models/everest_run_model.py index d4e014eb90b..76a0a522b96 100644 --- a/src/ert/run_models/everest_run_model.py +++ b/src/ert/run_models/everest_run_model.py @@ -10,14 +10,10 @@ from collections import defaultdict from collections.abc import Callable, Mapping from dataclasses import dataclass +from enum import IntEnum from pathlib import Path from types import TracebackType -from typing import ( - TYPE_CHECKING, - Any, - Literal, - Protocol, -) +from typing import TYPE_CHECKING, Any, Protocol import numpy as np import seba_sqlite.sqlite_storage @@ -146,6 +142,14 @@ def from_seba_optimal_result( ) +class EverestExitCode(IntEnum): + COMPLETED = 1 + TOO_FEW_REALIZATIONS = 2 + MAX_FUNCTIONS_REACHED = 3 + MAX_BATCH_NUM_REACHED = 4 + USER_ABORT = 5 + + class EverestRunModel(BaseRunModel): def __init__( self, @@ -173,10 +177,7 @@ def __init__( self._opt_callback = optimization_callback self._fm_errors: dict[int, dict[str, Any]] = {} self._result: OptimalResult | None = None - self._exit_code: Literal["max_batch_num_reached"] | OptimizerExitCode | None = ( - None - ) - self._max_batch_num_reached = False + self._exit_code: EverestExitCode | None = None self._simulator_cache = ( SimulatorCache() if ( @@ -248,11 +249,21 @@ def run_experiment( seba_storage.get_optimal_result() # type: ignore ) - self._exit_code = ( - "max_batch_num_reached" - if self._max_batch_num_reached - else optimizer_exit_code - ) + if self._exit_code is None: + self._exit_code = self._get_everest_exit_code(optimizer_exit_code) + + def _get_everest_exit_code( + self, optimizer_exit_code: OptimizerExitCode + ) -> EverestExitCode: + match optimizer_exit_code: + case OptimizerExitCode.MAX_FUNCTIONS_REACHED: + return EverestExitCode.MAX_FUNCTIONS_REACHED + case OptimizerExitCode.USER_ABORT: + return EverestExitCode.USER_ABORT + case OptimizerExitCode.TOO_FEW_REALIZATIONS: + return EverestExitCode.TOO_FEW_REALIZATIONS + case _: + return EverestExitCode.COMPLETED def check_if_runpath_exists(self) -> bool: return ( @@ -321,7 +332,7 @@ def _on_before_forward_model_evaluation( and self._everest_config.optimization.max_batch_num is not None and (self._batch_id >= self._everest_config.optimization.max_batch_num) ): - self._max_batch_num_reached = True + self._exit_code = EverestExitCode.MAX_BATCH_NUM_REACHED logging.getLogger(EVEREST).info("Maximum number of batches reached") optimizer.abort_optimization() if ( @@ -396,9 +407,7 @@ def description(cls) -> str: return "Run batches " @property - def exit_code( - self, - ) -> Literal["max_batch_num_reached"] | OptimizerExitCode | None: + def exit_code(self) -> EverestExitCode | None: return self._exit_code @property diff --git a/src/everest/detached/jobs/everserver.py b/src/everest/detached/jobs/everserver.py index 936ccedf236..02ae28a1172 100755 --- a/src/everest/detached/jobs/everserver.py +++ b/src/everest/detached/jobs/everserver.py @@ -30,11 +30,10 @@ HTTPBasic, HTTPBasicCredentials, ) -from ropt.enums import OptimizerExitCode from ert.config import QueueSystem from ert.ensemble_evaluator import EvaluatorServerConfig -from ert.run_models.everest_run_model import EverestRunModel +from ert.run_models.everest_run_model import EverestExitCode, EverestRunModel from everest import export_to_csv, export_with_progress from everest.config import EverestConfig, ServerConfig from everest.detached import ServerStatus, get_opt_status, update_everserver_status @@ -373,25 +372,31 @@ def main(): def _get_optimization_status(exit_code, shared_data): - if exit_code == "max_batch_num_reached": - return ServerStatus.completed, "Maximum number of batches reached." - - if exit_code == OptimizerExitCode.MAX_FUNCTIONS_REACHED: - return ServerStatus.completed, "Maximum number of function evaluations reached." + match exit_code: + case EverestExitCode.MAX_BATCH_NUM_REACHED: + return ServerStatus.completed, "Maximum number of batches reached." + + case EverestExitCode.MAX_FUNCTIONS_REACHED: + return ( + ServerStatus.completed, + "Maximum number of function evaluations reached.", + ) - if exit_code == OptimizerExitCode.USER_ABORT: - return ServerStatus.stopped, "Optimization aborted." + case EverestExitCode.USER_ABORT: + return ServerStatus.stopped, "Optimization aborted." - if exit_code == OptimizerExitCode.TOO_FEW_REALIZATIONS: - status = ( - ServerStatus.stopped if shared_data[STOP_ENDPOINT] else ServerStatus.failed - ) - messages = _failed_realizations_messages(shared_data) - for msg in messages: - logging.getLogger(EVEREST).error(msg) - return status, "\n".join(messages) - - return ServerStatus.completed, "Optimization completed." + case EverestExitCode.TOO_FEW_REALIZATIONS: + status = ( + ServerStatus.stopped + if shared_data[STOP_ENDPOINT] + else ServerStatus.failed + ) + messages = _failed_realizations_messages(shared_data) + for msg in messages: + logging.getLogger(EVEREST).error(msg) + return status, "\n".join(messages) + case _: + return ServerStatus.completed, "Optimization completed." def _failed_realizations_messages(shared_data): diff --git a/tests/everest/test_everserver.py b/tests/everest/test_everserver.py index f1e4f94d5a9..84224de944e 100644 --- a/tests/everest/test_everserver.py +++ b/tests/everest/test_everserver.py @@ -5,9 +5,9 @@ from pathlib import Path from unittest.mock import patch -from ropt.enums import OptimizerExitCode from seba_sqlite.snapshot import SebaSnapshot +from ert.run_models.everest_run_model import EverestExitCode from everest.config import EverestConfig, ServerConfig from everest.detached import ServerStatus, everserver_status from everest.detached.jobs import everserver @@ -33,8 +33,8 @@ def fail_optimization(self, from_ropt=False): # shared_data (see set_shared_status() below). self._sim_callback(None) if from_ropt: - self._exit_code = OptimizerExitCode.TOO_FEW_REALIZATIONS - return OptimizerExitCode.TOO_FEW_REALIZATIONS + self._exit_code = EverestExitCode.TOO_FEW_REALIZATIONS + return EverestExitCode.TOO_FEW_REALIZATIONS raise Exception("Failed optimization")