Skip to content

Commit

Permalink
Update exit_code code
Browse files Browse the repository at this point in the history
  • Loading branch information
verveerpj committed Dec 18, 2024
1 parent 200051c commit d07144e
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 32 deletions.
38 changes: 25 additions & 13 deletions src/ert/run_models/everest_run_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import datetime
from enum import IntEnum
import functools
import json
import logging
Expand Down Expand Up @@ -72,6 +73,14 @@ class OptimizerCallback(Protocol):
def __call__(self) -> str | None: ...


class EverestExitCode(IntEnum):
COMPLETED = 1
TOO_FEW_REALIZATIONS = 2
MAX_FUNCTIONS_REACHED = 3
MAX_BATCH_NUM_REACHED = 4
USER_ABORT = 5


@dataclass
class OptimalResult:
batch: int
Expand Down Expand Up @@ -120,10 +129,7 @@ def __init__(
self._fm_errors: dict[int, dict[str, Any]] = {}
self._display_all_jobs = display_all_jobs
self._result: OptimalResult | None = None
self._exit_code: Literal["max_batch_num_reached"] | OptimizerExitCode | None = (
None
)
self._max_batch_num_reached = False
self._exit_code: EverestExitCode | None = None
self._evaluator_cache: _EvaluatorCache | None = None
if (
everest_config.simulator is not None
Expand Down Expand Up @@ -180,9 +186,7 @@ def description(cls) -> str:
return "Run batches "

@property
def exit_code(
self,
) -> Literal["max_batch_num_reached"] | OptimizerExitCode | None:
def exit_code(self) -> EverestExitCode | None:
return self._exit_code

@property
Expand Down Expand Up @@ -226,11 +230,8 @@ def run_experiment(
seba_storage.get_optimal_result() # type: ignore
)

self._exit_code = (
"max_batch_num_reached"
if self._max_batch_num_reached
else optimizer_exit_code
)
if self._exit_code is None:
self._exit_code = self._get_exit_code(optimizer_exit_code)

def _create_optimizer(self) -> BasicOptimizer:
RESULT_COLUMNS = {
Expand Down Expand Up @@ -332,6 +333,17 @@ def _create_optimizer(self) -> BasicOptimizer:

return optimizer

def _get_exit_code(self, optimizer_exit_code: OptimizerExitCode) -> EverestExitCode:
match optimizer_exit_code:
case OptimizerExitCode.MAX_FUNCTIONS_REACHED:
return EverestExitCode.MAX_FUNCTIONS_REACHED
case OptimizerExitCode.USER_ABORT:
return EverestExitCode.USER_ABORT
case OptimizerExitCode.TOO_FEW_REALIZATIONS:
return EverestExitCode.TOO_FEW_REALIZATIONS
case _:
return EverestExitCode.COMPLETED

def _on_before_forward_model_evaluation(
self, _: OptimizerEvent, optimizer: BasicOptimizer
) -> None:
Expand All @@ -342,7 +354,7 @@ def _on_before_forward_model_evaluation(
and self.everest_config.optimization.max_batch_num is not None
and (self.batch_id >= self.everest_config.optimization.max_batch_num)
):
self._max_batch_num_reached = True
self._exit_code = EverestExitCode.MAX_BATCH_NUM_REACHED
logging.getLogger(EVEREST).info("Maximum number of batches reached")
optimizer.abort_optimization()
if (
Expand Down
43 changes: 24 additions & 19 deletions src/everest/detached/jobs/everserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,10 @@
HTTPBasic,
HTTPBasicCredentials,
)
from ropt.enums import OptimizerExitCode

from ert.config import QueueSystem
from ert.ensemble_evaluator import EvaluatorServerConfig
from ert.run_models.everest_run_model import EverestRunModel
from ert.run_models.everest_run_model import EverestExitCode, EverestRunModel
from everest import export_to_csv, export_with_progress
from everest.config import EverestConfig, ServerConfig
from everest.detached import ServerStatus, get_opt_status, update_everserver_status
Expand Down Expand Up @@ -373,25 +372,31 @@ def main():


def _get_optimization_status(exit_code, shared_data):
if exit_code == "max_batch_num_reached":
return ServerStatus.completed, "Maximum number of batches reached."

if exit_code == OptimizerExitCode.MAX_FUNCTIONS_REACHED:
return ServerStatus.completed, "Maximum number of function evaluations reached."
match exit_code:
case EverestExitCode.MAX_BATCH_NUM_REACHED:
return ServerStatus.completed, "Maximum number of batches reached."

case EverestExitCode.MAX_FUNCTIONS_REACHED:
return (
ServerStatus.completed,
"Maximum number of function evaluations reached.",
)

if exit_code == OptimizerExitCode.USER_ABORT:
return ServerStatus.stopped, "Optimization aborted."
case EverestExitCode.USER_ABORT:
return ServerStatus.stopped, "Optimization aborted."

if exit_code == OptimizerExitCode.TOO_FEW_REALIZATIONS:
status = (
ServerStatus.stopped if shared_data[STOP_ENDPOINT] else ServerStatus.failed
)
messages = _failed_realizations_messages(shared_data)
for msg in messages:
logging.getLogger(EVEREST).error(msg)
return status, "\n".join(messages)

return ServerStatus.completed, "Optimization completed."
case EverestExitCode.TOO_FEW_REALIZATIONS:
status = (
ServerStatus.stopped
if shared_data[STOP_ENDPOINT]
else ServerStatus.failed
)
messages = _failed_realizations_messages(shared_data)
for msg in messages:
logging.getLogger(EVEREST).error(msg)
return status, "\n".join(messages)
case _:
return ServerStatus.completed, "Optimization completed."


def _failed_realizations_messages(shared_data):
Expand Down

0 comments on commit d07144e

Please sign in to comment.