diff --git a/src/ert/ensemble_evaluator/_builder/_legacy.py b/src/ert/ensemble_evaluator/_builder/_legacy.py index 4ae340e3d7f..c8e891bba91 100644 --- a/src/ert/ensemble_evaluator/_builder/_legacy.py +++ b/src/ert/ensemble_evaluator/_builder/_legacy.py @@ -5,7 +5,17 @@ import threading import uuid from functools import partial, partialmethod -from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional, Tuple +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Dict, + List, + Optional, + Sequence, + Tuple, +) from cloudevents.http.event import CloudEvent @@ -183,7 +193,7 @@ async def _evaluate_inner( # pylint: disable=too-many-branches # something is long running, the evaluator will know and should send # commands to the task in order to have it killed/retried. # See https://github.com/equinor/ert/issues/1229 - queue_evaluators = None + queue_evaluators: Optional[Sequence[Callable[[], None]]] = None if ( self._analysis_config.stop_long_running and self._analysis_config.minimum_required_realizations > 0 @@ -208,7 +218,7 @@ async def _evaluate_inner( # pylint: disable=too-many-branches result: str = await self._job_queue.execute( queue_evaluators, - ) # type: ignore + ) print(result) except Exception as exc: print(exc) diff --git a/src/ert/job_queue/driver.py b/src/ert/job_queue/driver.py index 37eec6aec8c..e7f69e272c5 100644 --- a/src/ert/job_queue/driver.py +++ b/src/ert/job_queue/driver.py @@ -12,6 +12,7 @@ Dict, List, Optional, + Sequence, Tuple, ) @@ -28,22 +29,9 @@ class Driver(ABC): def __init__( self, - options: Optional[List[Tuple[str, str]]] = None, + options: Optional[Sequence[Tuple[str, str]]] = None, ): - self._options: Dict[str, str] = {} - - if options: - for key, value in options: - self.set_option(key, value) - - def set_option(self, option: str, value: str) -> None: - self._options.update({option: value}) - - def get_option(self, option_key: str) -> str: - return self._options[option_key] - - def has_option(self, option_key: str) -> bool: - return option_key in self._options + self.options: Dict[str, str] = dict(options or []) @abstractmethod async def submit(self, realization: "RealizationState") -> None: @@ -152,7 +140,7 @@ def __init__(self, queue_options: Optional[List[Tuple[str, str]]]): self._currently_polling = False async def run_with_retries( - self, func: Callable[[Any], Awaitable[Any]], error_msg: str = "" + self, func: Callable[[], Awaitable[Any]], error_msg: str = "" ) -> None: current_attempt = 0 while current_attempt < self._max_attempt: @@ -169,12 +157,10 @@ async def run_with_retries( async def submit(self, realization: "RealizationState") -> None: submit_cmd = self.build_submit_cmd( - [ - "-J", - f"poly_{realization.realization.run_arg.iens}", - str(realization.realization.job_script), - str(realization.realization.run_arg.runpath), - ] + "-J", + f"poly_{realization.realization.run_arg.iens}", + str(realization.realization.job_script), + str(realization.realization.run_arg.runpath), ) await self.run_with_retries( lambda: self._submit(submit_cmd, realization=realization), @@ -203,17 +189,15 @@ async def _submit( logger.info(f"Submitted job {realization} and got LSF JOBID {lsf_id}") return True - def build_submit_cmd(self, args: List[str]) -> List[str]: - submit_cmd = [ - self.get_option("BSUB_CMD") if self.has_option("BSUB_CMD") else "bsub" - ] - if self.has_option("LSF_QUEUE"): - submit_cmd += ["-q", self.get_option("LSF_QUEUE")] + def build_submit_cmd(self, *args: str) -> List[str]: + submit_cmd = [self.options.get("BSUB_CMD", "bsub")] + if (lsf_queue := self.options.get("LSF_QUEUE")) is not None: + submit_cmd += ["-q", lsf_queue] - return submit_cmd + args + return [*submit_cmd, *args] async def run_shell_command( - self, command_to_run: List[str], command_name: str="" + self, command_to_run: List[str], command_name: str = "" ) -> Optional[Tuple[asyncio.subprocess.Process, bytes, bytes]]: process = await asyncio.create_subprocess_exec( *command_to_run, @@ -243,10 +227,9 @@ async def poll_statuses(self) -> None: return poll_cmd = [ - str(self.get_option("BJOBS_CMD")) - if self.has_option("BJOBS_CMD") - else "bjobs" - ] + list(self._realstate_to_lsfid.values()) + self.options.get("BJOBS_CMD", "bjobs"), + *self._realstate_to_lsfid.values(), + ] try: await self.run_with_retries(lambda: self._poll_statuses(poll_cmd)) # suppress runtime error @@ -309,7 +292,7 @@ async def kill(self, realization: "RealizationState") -> None: lsf_job_id = self._realstate_to_lsfid[realization] logger.debug(f"Attempting to kill {lsf_job_id=}") kill_cmd = [ - self.get_option("BKILL_CMD") if self.has_option("BKILL_CMD") else "bkill", + self.options.get("BKILL_CMD", "bkill"), lsf_job_id, ] await self.run_with_retries( diff --git a/src/ert/job_queue/queue.py b/src/ert/job_queue/queue.py index 11ce639ffee..4b7475c36c8 100644 --- a/src/ert/job_queue/queue.py +++ b/src/ert/job_queue/queue.py @@ -11,7 +11,7 @@ import ssl from collections import deque from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Sequence, Union from cloudevents.conversion import to_json from cloudevents.http import CloudEvent @@ -186,8 +186,8 @@ def _add_realization(self, realization: QueueableRealization) -> int: def max_running(self) -> int: max_running = 0 - if self.driver.has_option("MAX_RUNNING"): - max_running = int(self.driver.get_option("MAX_RUNNING")) + if (value := self.driver.options.get("MAX_RUNNING")) is not None: + max_running = int(value) if max_running == 0: return len(self._realizations) return max_running @@ -308,7 +308,7 @@ async def _realization_statechange_publisher(self) -> None: async def execute( self, - evaluators: Optional[List[Callable[..., Any]]] = None, + evaluators: Optional[Sequence[Callable[[], None]]] = None, ) -> str: if evaluators is None: evaluators = [] diff --git a/src/ert/job_queue/realization_state.py b/src/ert/job_queue/realization_state.py index 23e3b2ef5d7..38daa64469a 100644 --- a/src/ert/job_queue/realization_state.py +++ b/src/ert/job_queue/realization_state.py @@ -140,13 +140,10 @@ def on_enter_EXIT(self) -> None: ) if exit_file_path.exists(): exit_file = etree.parse(exit_file_path) - failed_job = exit_file.find("job").text - error_reason = exit_file.find("reason").text - stderr_capture = exit_file.find("stderr").text - - stderr_file = "" - if stderr_file_node := exit_file.find("stderr_file"): - stderr_file = stderr_file_node.text + failed_job = exit_file.findtext("job", default="") + error_reason = exit_file.findtext("reason", default="") + stderr_capture = exit_file.findtext("stderr", default="") + stderr_file = exit_file.findtext("stderr_file", default="") logger.error( f"job {failed_job} failed with: '{error_reason}'\n" diff --git a/src/ert/simulator/simulation_context.py b/src/ert/simulator/simulation_context.py index 1a72689e913..24073249472 100644 --- a/src/ert/simulator/simulation_context.py +++ b/src/ert/simulator/simulation_context.py @@ -4,7 +4,7 @@ from functools import partial from threading import Thread from time import sleep -from typing import TYPE_CHECKING, Any, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Callable, List, Optional, Sequence, Tuple import numpy as np @@ -61,7 +61,7 @@ def _run_forward_model( ert.ert_config.preferred_num_cpu, ) - queue_evaluators = None + queue_evaluators: Optional[Sequence[Callable[[], None]]] = None if ( ert.ert_config.analysis_config.stop_long_running and ert.ert_config.analysis_config.minimum_required_realizations > 0 @@ -73,7 +73,7 @@ def _run_forward_model( ) ] - asyncio.run(job_queue.execute(evaluators=queue_evaluators)) # type: ignore + asyncio.run(job_queue.execute(queue_evaluators)) run_context.sim_fs.sync() diff --git a/tests/unit_tests/job_queue/_test_driver.py b/tests/unit_tests/job_queue/_test_driver.py index 7ac008ea22f..c113cff4a12 100644 --- a/tests/unit_tests/job_queue/_test_driver.py +++ b/tests/unit_tests/job_queue/_test_driver.py @@ -6,31 +6,6 @@ from ert.job_queue import Driver -@pytest.mark.xfail(reason="Needs reimplementation") -def test_set_and_unset_option(): - queue_config = QueueConfig( - job_script="script.sh", - queue_system=QueueSystem.LOCAL, - max_submit=2, - queue_options={ - QueueSystem.LOCAL: [ - ("MAX_RUNNING", "50"), - ("MAX_RUNNING", ""), - ] - }, - ) - driver = Driver.create_driver(queue_config) - assert driver.get_option("MAX_RUNNING") == "0" - assert driver.set_option("MAX_RUNNING", "42") - assert driver.get_option("MAX_RUNNING") == "42" - driver.set_option("MAX_RUNNING", "") - assert driver.get_option("MAX_RUNNING") == "0" - driver.set_option("MAX_RUNNING", "100") - assert driver.get_option("MAX_RUNNING") == "100" - driver.set_option("MAX_RUNNING", "0") - assert driver.get_option("MAX_RUNNING") == "0" - - @pytest.mark.xfail(reason="Needs reimplementation") def test_get_driver_name(): queue_config = QueueConfig(queue_system=QueueSystem.LOCAL) @@ -61,8 +36,6 @@ def test_get_slurm_queue_config(): assert queue_config.queue_system == QueueSystem.SLURM driver = Driver.create_driver(queue_config) - assert driver.get_option("SBATCH") == "/path/to/sbatch" - assert driver.get_option("SCONTROL") == "scontrol" - driver.set_option("SCONTROL", "") - assert driver.get_option("SCONTROL") == "" + assert driver.options["SBATCH"] == "/path/to/sbatch" + assert driver.options["SCONTROL"] == "scontrol" assert driver.name == "SLURM"