Skip to content

Commit

Permalink
EverestRunModel: Introduce EverestExitCode
Browse files Browse the repository at this point in the history
  • Loading branch information
verveerpj committed Dec 18, 2024
1 parent 63e13b3 commit 891f8ca
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 41 deletions.
47 changes: 28 additions & 19 deletions src/ert/run_models/everest_run_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,10 @@
from collections import defaultdict
from collections.abc import Callable, Mapping
from dataclasses import dataclass
from enum import IntEnum
from pathlib import Path
from types import TracebackType
from typing import (
TYPE_CHECKING,
Any,
Literal,
Protocol,
)
from typing import TYPE_CHECKING, Any, Protocol

import numpy as np
import seba_sqlite.sqlite_storage
Expand Down Expand Up @@ -146,6 +142,14 @@ def from_seba_optimal_result(
)


class EverestExitCode(IntEnum):
COMPLETED = 1
TOO_FEW_REALIZATIONS = 2
MAX_FUNCTIONS_REACHED = 3
MAX_BATCH_NUM_REACHED = 4
USER_ABORT = 5


class EverestRunModel(BaseRunModel):
def __init__(
self,
Expand Down Expand Up @@ -173,10 +177,7 @@ def __init__(
self._opt_callback = optimization_callback
self._fm_errors: dict[int, dict[str, Any]] = {}
self._result: OptimalResult | None = None
self._exit_code: Literal["max_batch_num_reached"] | OptimizerExitCode | None = (
None
)
self._max_batch_num_reached = False
self._exit_code: EverestExitCode | None = None
self._simulator_cache = (
SimulatorCache()
if (
Expand Down Expand Up @@ -248,11 +249,21 @@ def run_experiment(
seba_storage.get_optimal_result() # type: ignore
)

self._exit_code = (
"max_batch_num_reached"
if self._max_batch_num_reached
else optimizer_exit_code
)
if self._exit_code is None:
self._exit_code = self._get_everest_exit_code(optimizer_exit_code)

def _get_everest_exit_code(
self, optimizer_exit_code: OptimizerExitCode
) -> EverestExitCode:
match optimizer_exit_code:
case OptimizerExitCode.MAX_FUNCTIONS_REACHED:
return EverestExitCode.MAX_FUNCTIONS_REACHED
case OptimizerExitCode.USER_ABORT:
return EverestExitCode.USER_ABORT
case OptimizerExitCode.TOO_FEW_REALIZATIONS:
return EverestExitCode.TOO_FEW_REALIZATIONS
case _:
return EverestExitCode.COMPLETED

def check_if_runpath_exists(self) -> bool:
return (
Expand Down Expand Up @@ -321,7 +332,7 @@ def _on_before_forward_model_evaluation(
and self._everest_config.optimization.max_batch_num is not None
and (self._batch_id >= self._everest_config.optimization.max_batch_num)
):
self._max_batch_num_reached = True
self._exit_code = EverestExitCode.MAX_BATCH_NUM_REACHED
logging.getLogger(EVEREST).info("Maximum number of batches reached")
optimizer.abort_optimization()
if (
Expand Down Expand Up @@ -396,9 +407,7 @@ def description(cls) -> str:
return "Run batches "

@property
def exit_code(
self,
) -> Literal["max_batch_num_reached"] | OptimizerExitCode | None:
def exit_code(self) -> EverestExitCode | None:
return self._exit_code

@property
Expand Down
43 changes: 24 additions & 19 deletions src/everest/detached/jobs/everserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,10 @@
HTTPBasic,
HTTPBasicCredentials,
)
from ropt.enums import OptimizerExitCode

from ert.config import QueueSystem
from ert.ensemble_evaluator import EvaluatorServerConfig
from ert.run_models.everest_run_model import EverestRunModel
from ert.run_models.everest_run_model import EverestExitCode, EverestRunModel
from everest import export_to_csv, export_with_progress
from everest.config import EverestConfig, ServerConfig
from everest.detached import ServerStatus, get_opt_status, update_everserver_status
Expand Down Expand Up @@ -373,25 +372,31 @@ def main():


def _get_optimization_status(exit_code, shared_data):
if exit_code == "max_batch_num_reached":
return ServerStatus.completed, "Maximum number of batches reached."

if exit_code == OptimizerExitCode.MAX_FUNCTIONS_REACHED:
return ServerStatus.completed, "Maximum number of function evaluations reached."
match exit_code:
case EverestExitCode.MAX_BATCH_NUM_REACHED:
return ServerStatus.completed, "Maximum number of batches reached."

case EverestExitCode.MAX_FUNCTIONS_REACHED:
return (
ServerStatus.completed,
"Maximum number of function evaluations reached.",
)

if exit_code == OptimizerExitCode.USER_ABORT:
return ServerStatus.stopped, "Optimization aborted."
case EverestExitCode.USER_ABORT:
return ServerStatus.stopped, "Optimization aborted."

if exit_code == OptimizerExitCode.TOO_FEW_REALIZATIONS:
status = (
ServerStatus.stopped if shared_data[STOP_ENDPOINT] else ServerStatus.failed
)
messages = _failed_realizations_messages(shared_data)
for msg in messages:
logging.getLogger(EVEREST).error(msg)
return status, "\n".join(messages)

return ServerStatus.completed, "Optimization completed."
case EverestExitCode.TOO_FEW_REALIZATIONS:
status = (
ServerStatus.stopped
if shared_data[STOP_ENDPOINT]
else ServerStatus.failed
)
messages = _failed_realizations_messages(shared_data)
for msg in messages:
logging.getLogger(EVEREST).error(msg)
return status, "\n".join(messages)
case _:
return ServerStatus.completed, "Optimization completed."


def _failed_realizations_messages(shared_data):
Expand Down
6 changes: 3 additions & 3 deletions tests/everest/test_everserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
from pathlib import Path
from unittest.mock import patch

from ropt.enums import OptimizerExitCode
from seba_sqlite.snapshot import SebaSnapshot

from ert.run_models.everest_run_model import EverestExitCode
from everest.config import EverestConfig, ServerConfig
from everest.detached import ServerStatus, everserver_status
from everest.detached.jobs import everserver
Expand All @@ -33,8 +33,8 @@ def fail_optimization(self, from_ropt=False):
# shared_data (see set_shared_status() below).
self._sim_callback(None)
if from_ropt:
self._exit_code = OptimizerExitCode.TOO_FEW_REALIZATIONS
return OptimizerExitCode.TOO_FEW_REALIZATIONS
self._exit_code = EverestExitCode.TOO_FEW_REALIZATIONS
return EverestExitCode.TOO_FEW_REALIZATIONS

raise Exception("Failed optimization")

Expand Down

0 comments on commit 891f8ca

Please sign in to comment.