From 3d12038f8178dc45047db7a0430a520f7d488dfd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=A5vard=20Berland?= Date: Mon, 27 Nov 2023 06:50:23 +0100 Subject: [PATCH] Fix some typing --- src/ert/job_queue/driver.py | 35 +++++++++++++++++++---------------- 1 file changed, 19 insertions(+), 16 deletions(-) diff --git a/src/ert/job_queue/driver.py b/src/ert/job_queue/driver.py index 5e7c6ecaa2a..37eec6aec8c 100644 --- a/src/ert/job_queue/driver.py +++ b/src/ert/job_queue/driver.py @@ -6,6 +6,7 @@ from abc import ABC, abstractmethod from typing import ( TYPE_CHECKING, + Any, Awaitable, Callable, Dict, @@ -143,8 +144,7 @@ def __init__(self, queue_options: Optional[List[Tuple[str, str]]]): self._realstate_to_lsfid: Dict["RealizationState", str] = {} self._lsfid_to_realstate: Dict[str, "RealizationState"] = {} - self._max_attempt = 100 - self._statuses = {} + self._max_attempt: int = 100 self._submit_processes: Dict[ "RealizationState", "asyncio.subprocess.Process" ] = {} @@ -152,8 +152,8 @@ def __init__(self, queue_options: Optional[List[Tuple[str, str]]]): self._currently_polling = False async def run_with_retries( - self, func: Callable[[None], Awaitable], error_msg: str = "" - ): + self, func: Callable[[Any], Awaitable[Any]], error_msg: str = "" + ) -> None: current_attempt = 0 while current_attempt < self._max_attempt: current_attempt += 1 @@ -168,11 +168,13 @@ async def run_with_retries( raise RuntimeError(error_msg) async def submit(self, realization: "RealizationState") -> None: - submit_cmd = self.parse_submit_cmd( - "-J", - f"poly_{realization.realization.run_arg.iens}", - str(realization.realization.job_script), - str(realization.realization.run_arg.runpath), + submit_cmd = self.build_submit_cmd( + [ + "-J", + f"poly_{realization.realization.run_arg.iens}", + str(realization.realization.job_script), + str(realization.realization.run_arg.runpath), + ] ) await self.run_with_retries( lambda: self._submit(submit_cmd, realization=realization), @@ -180,7 +182,7 @@ async def submit(self, realization: "RealizationState") -> None: ) async def _submit( - self, submit_command: list[str], realization: "RealizationState" + self, submit_command: List[str], realization: "RealizationState" ) -> bool: result = await self.run_shell_command(submit_command, command_name="bsub") if not result: @@ -201,18 +203,18 @@ async def _submit( logger.info(f"Submitted job {realization} and got LSF JOBID {lsf_id}") return True - def parse_submit_cmd(self, *additional_parameters) -> List[str]: + def build_submit_cmd(self, args: List[str]) -> List[str]: submit_cmd = [ self.get_option("BSUB_CMD") if self.has_option("BSUB_CMD") else "bsub" ] if self.has_option("LSF_QUEUE"): submit_cmd += ["-q", self.get_option("LSF_QUEUE")] - return submit_cmd + list(additional_parameters) + return submit_cmd + args async def run_shell_command( - self, command_to_run: list[str], command_name="" - ) -> (asyncio.subprocess.Process, str, str): + self, command_to_run: List[str], command_name: str="" + ) -> Optional[Tuple[asyncio.subprocess.Process, bytes, bytes]]: process = await asyncio.create_subprocess_exec( *command_to_run, stdout=asyncio.subprocess.PIPE, @@ -254,7 +256,7 @@ async def poll_statuses(self) -> None: # raise this value error as runtime error raise RuntimeError from e - async def _poll_statuses(self, poll_cmd: str) -> bool: + async def _poll_statuses(self, poll_cmd: List[str]) -> bool: self._currently_polling = True result = await self.run_shell_command(poll_cmd, command_name="bjobs") @@ -316,10 +318,11 @@ async def kill(self, realization: "RealizationState") -> None: ) async def _kill( - self, kill_cmd, realization: "RealizationState", lsf_job_id: int + self, kill_cmd: List[str], realization: "RealizationState", lsf_job_id: str ) -> bool: result = await self.run_shell_command(kill_cmd, "bkill") if result is None: return False realization.verify_kill() logger.info(f"Successfully killed job {lsf_job_id}") + return True