diff --git a/src/ert/scheduler/job.py b/src/ert/scheduler/job.py index 662b1b8f28e..be26b443952 100644 --- a/src/ert/scheduler/job.py +++ b/src/ert/scheduler/job.py @@ -2,7 +2,7 @@ import asyncio from enum import Enum -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Callable, Optional from cloudevents.conversion import to_json from cloudevents.http import CloudEvent @@ -47,12 +47,18 @@ class Job: (LSF, PBS, SLURM, etc.) """ - def __init__(self, scheduler: Scheduler, real: Realization) -> None: + def __init__( + self, + scheduler: Scheduler, + real: Realization, + callback_timeout: Optional[Callable[[int], None]] = None, + ) -> None: self.real = real self.started = asyncio.Event() self.returncode: asyncio.Future[int] = asyncio.Future() self.aborted = asyncio.Event() self._scheduler = scheduler + self._callback_timeout: Optional[Callable[[int], None]] = callback_timeout @property def iens(self) -> int: @@ -67,6 +73,7 @@ async def __call__( ) -> None: await start.wait() await sem.acquire() + cancel_task: Optional[asyncio.Task[None]] = None try: await self._send(State.SUBMITTING) @@ -78,6 +85,8 @@ async def __call__( await self.started.wait() await self._send(State.RUNNING) + if self.real.max_runtime is not None and self.real.max_runtime > 0: + cancel_task = asyncio.create_task(self._cancel_from_timeout()) returncode = await self.returncode if ( returncode == 0 @@ -89,14 +98,25 @@ async def __call__( await self._send(State.FAILED) except asyncio.CancelledError: - await self._send(State.ABORTING) - await self.driver.kill(self.iens) - - await self.aborted.wait() - await self._send(State.ABORTED) + await self._cancel() finally: + if cancel_task and not cancel_task.done(): + cancel_task.cancel() sem.release() + async def _cancel_from_timeout(self) -> None: + assert self.real.max_runtime is not None + await asyncio.sleep(self.real.max_runtime) + if self._callback_timeout is not None: + self._callback_timeout(self.real.iens) + self.returncode.cancel() # Triggers CancelledError + + async def _cancel(self) -> None: + await self._send(State.ABORTING) + await self.driver.kill(self.iens) + await self.aborted.wait() + await self._send(State.ABORTED) + async def _send(self, state: State) -> None: status = STATE_TO_LEGACY[state] event = CloudEvent( diff --git a/src/ert/scheduler/scheduler.py b/src/ert/scheduler/scheduler.py index 82307222afa..219ecbaaf3d 100644 --- a/src/ert/scheduler/scheduler.py +++ b/src/ert/scheduler/scheduler.py @@ -63,9 +63,11 @@ async def ainit(self) -> None: self._events = asyncio.Queue() def add_realization( - self, real: Realization, callback_timeout: Callable[[int], None] + self, + real: Realization, + callback_timeout: Optional[Callable[[int], None]] = None, ) -> None: - self._jobs[real.iens] = Job(self, real) + self._jobs[real.iens] = Job(self, real, callback_timeout=callback_timeout) def kill_all_jobs(self) -> None: for task in self._tasks.values(): diff --git a/tests/unit_tests/scheduler/test_scheduler.py b/tests/unit_tests/scheduler/test_scheduler.py index d9813b5e3ee..9e526e137a0 100644 --- a/tests/unit_tests/scheduler/test_scheduler.py +++ b/tests/unit_tests/scheduler/test_scheduler.py @@ -92,7 +92,7 @@ async def test_cancel(tmp_path: Path, realization): realization.forward_models = [step] sch = scheduler.Scheduler() - sch.add_realization(realization, callback_timeout=lambda _: None) + sch.add_realization(realization) create_jobs_json(tmp_path, [step]) sch.add_dispatch_information_to_jobs_file() @@ -108,3 +108,25 @@ async def test_cancel(tmp_path: Path, realization): assert (tmp_path / "a").exists() assert not (tmp_path / "b").exists() + + +async def test_max_runtime_with_callback_timeout(tmp_path: Path, realization): + step = create_bash_step("touch a; sleep 10; touch b") + realization.forward_models = [step] + realization.max_runtime = 1 + + def mocked_callback(iens: int): + Path("from_callback").write_text(str(iens), encoding="utf-8") + + sch = scheduler.Scheduler() + sch.add_realization(realization, callback_timeout=mocked_callback) + + create_jobs_json(tmp_path, [step]) + sch.add_dispatch_information_to_jobs_file() + + scheduler_task = asyncio.create_task(sch.execute()) + await scheduler_task + + assert (tmp_path / "a").exists() + assert not (tmp_path / "b").exists() + assert Path("from_callback").read_text(encoding="utf-8") == str(realization.iens)