From 4ca8a667078fdba3af89c88bea040a5daf456746 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=A5vard=20Berland?= Date: Tue, 21 Nov 2023 13:54:44 +0100 Subject: [PATCH] Typing src/ert/job_queue --- .mypy.ini | 3 ++ src/ert/job_queue/driver.py | 37 +++++++++----------- src/ert/job_queue/queue.py | 16 +++++---- src/ert/job_queue/realization_state.py | 21 ++++------- tests/unit_tests/job_queue/test_job_queue.py | 15 ++++---- 5 files changed, 46 insertions(+), 46 deletions(-) diff --git a/.mypy.ini b/.mypy.ini index 3bffd3df154..f62ce7c6529 100644 --- a/.mypy.ini +++ b/.mypy.ini @@ -78,5 +78,8 @@ ignore_missing_imports = True [mypy-ruamel] ignore_missing_imports = True +[mypy-statemachine] +ignore_missing_imports = True + [mypy-ert.callbacks] ignore_errors = True diff --git a/src/ert/job_queue/driver.py b/src/ert/job_queue/driver.py index 92591828bda..bc9d759f6c8 100644 --- a/src/ert/job_queue/driver.py +++ b/src/ert/job_queue/driver.py @@ -8,7 +8,7 @@ if TYPE_CHECKING: from ert.config import QueueConfig - from ert.job_queue import QueueableRealization, RealizationState + from ert.job_queue import RealizationState class Driver(ABC): @@ -16,7 +16,7 @@ def __init__( self, options: Optional[List[Tuple[str, str]]] = None, ): - self._options = {} + self._options: Dict[str, str] = {} if options: for key, value in options: @@ -32,15 +32,15 @@ def has_option(self, option_key: str) -> bool: return option_key in self._options @abstractmethod - async def submit(self, realization: "RealizationState"): + async def submit(self, realization: "RealizationState") -> None: pass @abstractmethod - async def poll_statuses(self): + async def poll_statuses(self) -> None: pass @abstractmethod - async def kill(self, realization: "RealizationState"): + async def kill(self, realization: "RealizationState") -> None: pass @classmethod @@ -60,10 +60,10 @@ def __init__(self, queue_config: List[Tuple[str, str]]): self._currently_polling = False @property - def optionnames(self): + def optionnames(self) -> List[str]: return [] - async def submit(self, realization: "RealizationState"): + async def submit(self, realization: "RealizationState") -> None: """Submit and *actually (a)wait* for the process to finish.""" realization.accept() try: @@ -93,16 +93,16 @@ async def submit(self, realization: "RealizationState"): realization.runfail() # TODO: fetch stdout/stderr - async def poll_statuses(self): + async def poll_statuses(self) -> None: pass - async def kill(self, realization: "RealizationState"): + async def kill(self, realization: "RealizationState") -> None: self._processes[realization].kill() realization.verify_kill() class LSFDriver(Driver): - def __init__(self, queue_options): + def __init__(self, queue_options: Optional[List[Tuple[str, str]]]): super().__init__(queue_options) self._realstate_to_lsfid: Dict["RealizationState", str] = {} @@ -113,15 +113,14 @@ def __init__(self, queue_options): self._currently_polling = False - async def submit(self, realization: "RealizationState"): - submit_cmd = [ + async def submit(self, realization: "RealizationState") -> None: + submit_cmd: List[str] = [ "bsub", "-J", f"poly_{realization.realization.run_arg.iens}", - realization.realization.job_script, - realization.realization.run_arg.runpath, + str(realization.realization.job_script), + str(realization.realization.run_arg.runpath), ] - assert shutil.which(submit_cmd[0]) # does not propagate back.. process = await asyncio.create_subprocess_exec( *submit_cmd, stdout=asyncio.subprocess.PIPE, @@ -142,13 +141,11 @@ async def submit(self, realization: "RealizationState"): print(f"Submitted job {realization} and got LSF JOBID {lsf_id}") except Exception: # We should probably retry the submission, bsub stdout seems flaky. - print(f"ERROR: Could not parse lsf id from: {output}") + print(f"ERROR: Could not parse lsf id from: {output!r}") async def poll_statuses(self) -> None: if self._currently_polling: - # Don't repeat if we are called too often. - # So easy in async.. - return self._statuses + return self._currently_polling = True if not self._realstate_to_lsfid: @@ -198,6 +195,6 @@ async def poll_statuses(self) -> None: self._currently_polling = False - async def kill(self, realization): + async def kill(self, realization: "RealizationState") -> None: print(f"would like to kill {realization}") pass diff --git a/src/ert/job_queue/queue.py b/src/ert/job_queue/queue.py index 45801fbaa3b..74f13cd7f8a 100644 --- a/src/ert/job_queue/queue.py +++ b/src/ert/job_queue/queue.py @@ -119,10 +119,12 @@ def realization_state(self, iens: int) -> RealizationState: def count_realization_state(self, state: RealizationState) -> int: return sum(real.current_state == state for real in self._realizations) - async def run_done_callback(self, state: RealizationState): + async def run_done_callback(self, state: RealizationState) -> Optional[LoadStatus]: callback_status, status_msg = forward_model_ok(state.realization.run_arg) if callback_status == LoadStatus.LOAD_SUCCESSFUL: state.validate() + # todo: implement me + return None @property def stopped(self) -> bool: @@ -162,7 +164,9 @@ def queue_size(self) -> int: def _add_realization(self, realization: QueueableRealization) -> int: self._realizations.append( - RealizationState(self, realization, retries=self._queue_config.max_submit - 1) + RealizationState( + self, realization, retries=self._queue_config.max_submit - 1 + ) ) return len(self._realizations) - 1 @@ -311,9 +315,10 @@ async def execute( await self.driver.poll_statuses() for real in self._realizations: + if real.realization.max_runtime is None: + continue if ( - real.realization.max_runtime != None - and real.current_state == RealizationState.RUNNING + real.current_state == RealizationState.RUNNING and real.start_time and datetime.datetime.now() - real.start_time > datetime.timedelta(seconds=real.realization.max_runtime) @@ -353,7 +358,6 @@ async def execute( return EVTYPE_ENSEMBLE_STOPPED - def add_realization_from_run_arg( self, run_arg: "RunArg", @@ -412,7 +416,7 @@ def stop_long_running_realizations( sum(real.runtime for real in completed) / finished_realizations ) - for job in self.job_list: + for job in self.job_list: # type: ignore if job.runtime > LONG_RUNNING_FACTOR * average_runtime: job.stop() diff --git a/src/ert/job_queue/realization_state.py b/src/ert/job_queue/realization_state.py index f386f97046c..697e993a699 100644 --- a/src/ert/job_queue/realization_state.py +++ b/src/ert/job_queue/realization_state.py @@ -33,15 +33,8 @@ class QueueableRealization: # Aka "Job" or previously "JobQueueNode" max_runtime: Optional[int] = None callback_timeout: Optional[Callable[[int], None]] = None - def __hash__(self): - # Elevate iens up to two levels? Check if it can be removed from run_arg - return self.run_arg.iens - def __repr__(self): - return str(self.run_arg.iens) - - -class RealizationState(StateMachine): +class RealizationState(StateMachine): # type: ignore NOT_ACTIVE = State("NOT ACTIVE") WAITING = State("WAITING", initial=True) SUBMITTED = State("SUBMITTED") @@ -102,7 +95,7 @@ def __init__( donotgohere = UNKNOWN.to(STATUS_FAILURE) - def on_enter_state(self, target, event): + def on_enter_state(self, target: RealizationState) -> None: if self.jobqueue._changes_to_publish is None: return if target in ( @@ -116,21 +109,21 @@ def on_enter_state(self, target, event): change = {self.realization.run_arg.iens: target.id} asyncio.create_task(self.jobqueue._changes_to_publish.put(change)) - def on_enter_SUBMITTED(self): + def on_enter_SUBMITTED(self) -> None: asyncio.create_task(self.jobqueue.driver.submit(self)) - def on_enter_RUNNING(self): + def on_enter_RUNNING(self) -> None: self.start_time = datetime.datetime.now() - def on_enter_EXIT(self): + def on_enter_EXIT(self) -> None: if self.retries_left > 0: self.retry() self.retries_left -= 1 else: self.invalidate() - def on_enter_DONE(self): + def on_enter_DONE(self) -> None: asyncio.create_task(self.jobqueue.run_done_callback(self)) - def on_enter_DO_KILL(self): + def on_enter_DO_KILL(self) -> None: asyncio.create_task(self.jobqueue.driver.kill(self)) diff --git a/tests/unit_tests/job_queue/test_job_queue.py b/tests/unit_tests/job_queue/test_job_queue.py index fd4270d2bcb..32f15fa2dae 100644 --- a/tests/unit_tests/job_queue/test_job_queue.py +++ b/tests/unit_tests/job_queue/test_job_queue.py @@ -129,9 +129,12 @@ async def test_max_submit(tmpdir, monkeypatch, failing_script, max_submit_num): await job_queue.stop_jobs_async() await asyncio.gather(execute_task) + @pytest.mark.asyncio @pytest.mark.parametrize("max_submit_num", [1, 3]) -async def test_that_kill_queue_disregards_max_submit(tmpdir, max_submit_num, monkeypatch, simple_script): +async def test_that_kill_queue_disregards_max_submit( + tmpdir, max_submit_num, monkeypatch, simple_script +): monkeypatch.chdir(tmpdir) job_queue = create_local_queue(simple_script, max_submit=max_submit_num) await job_queue.stop_jobs_async() @@ -140,7 +143,11 @@ async def test_that_kill_queue_disregards_max_submit(tmpdir, max_submit_num, mon print(tmpdir) for iens in range(job_queue.queue_size): assert not Path(f"dummy_path_{iens}/STATUS").exists() - assert job_queue.count_realization_state(RealizationState.IS_KILLED) == job_queue.queue_size + assert ( + job_queue.count_realization_state(RealizationState.IS_KILLED) + == job_queue.queue_size + ) + @pytest.mark.asyncio @pytest.mark.timeout(5) @@ -231,9 +238,6 @@ def test_add_dispatch_info_cert_none(tmpdir, monkeypatch, simple_script): assert not (runpath / cert_file).exists() - - - @pytest.mark.skip(reason="Needs reimplementation") def test_stop_long_running(): """ @@ -283,7 +287,6 @@ def test_stop_long_running(): assert queue.snapshot()[i] == str(JobStatus.RUNNING) - @pytest.mark.usefixtures("use_tmpdir", "mock_fm_ok") @pytest.mark.skip(reason="Needs reimplementation") def test_num_cpu_submitted_correctly_lsf(tmpdir, simple_script):