# Patch to src/ert/scheduler/job.py in git unified-diff form. The original line
# breaks were lost, so every hunk below runs together on two physical lines;
# the text of the patch itself is left untouched.
#
# What the -/+ markers in this patch show:
#   * The typing import drops TypeVar and the module-level _T is removed; _T was
#     used by the nested helper threadsafer_future_wait(), which is also removed.
#   * _submit_and_run_once() no longer wraps the wait on self.returncode in
#     asyncio.wait_for(..., timeout=self.real.max_runtime). Instead, when
#     self.real.max_runtime is not None and > 0, it starts
#     self._max_runtime_task() with asyncio.create_task(), then polls
#     self.returncode.done() with 0.01 s sleeps before awaiting the future.
#   * The finally: block cancels that task if it exists and is not done, and
#     then releases the semaphore.
#   * A new method _max_runtime_task() sleeps for self.real.max_runtime, puts an
#     EVTYPE_REALIZATION_TIMEOUT CloudEvent (serialised with to_json) on
#     self._scheduler._events, and finally calls self.returncode.cancel().
#
# NOTE(review): the removed code asserted `self._scheduler._events is not None`
# before calling put(); the added method does not. Confirm _events can no
# longer be None at this point.
# NOTE(review): the added code relies on self.returncode.cancel() making the
# later `await self.returncode` raise CancelledError (per the inline comment);
# how that exception is handled is outside the lines visible in this patch.
diff --git a/src/ert/scheduler/job.py b/src/ert/scheduler/job.py index 174bff74f29..d7c75f169af 100644 --- a/src/ert/scheduler/job.py +++ b/src/ert/scheduler/job.py @@ -4,7 +4,7 @@ import logging import uuid from enum import Enum -from typing import TYPE_CHECKING, Callable, Optional, TypeVar +from typing import TYPE_CHECKING, Callable, Optional from cloudevents.conversion import to_json from cloudevents.http import CloudEvent @@ -20,7 +20,6 @@ from ert.scheduler.scheduler import Scheduler logger = logging.getLogger(__name__) -_T = TypeVar("_T") class State(str, Enum): @@ -75,6 +74,7 @@ def driver(self) -> Driver: async def _submit_and_run_once(self, sem: asyncio.BoundedSemaphore) -> None: await sem.acquire() + cancel_task: Optional[asyncio.Task[None]] = None try: await self._send(State.SUBMITTING) @@ -86,31 +86,11 @@ async def _submit_and_run_once(self, sem: asyncio.BoundedSemaphore) -> None: await self.started.wait() await self._send(State.RUNNING) - - async def threadsafer_future_wait(future: asyncio.Future[_T]) -> _T: - while not future.done(): - await asyncio.sleep(0.01) - returncode = await future - return returncode - - try: - returncode = await asyncio.wait_for( - threadsafer_future_wait(self.returncode), - timeout=self.real.max_runtime, - ) - except asyncio.TimeoutError: - assert self.real.max_runtime is not None - await asyncio.sleep(self.real.max_runtime) - timeout_event = CloudEvent( - { - "type": EVTYPE_REALIZATION_TIMEOUT, - "source": f"/ert/ensemble/{self._scheduler._ens_id}/real/{self.iens}", - "id": str(uuid.uuid1()), - } - ) - assert self._scheduler._events is not None - await self._scheduler._events.put(to_json(timeout_event)) - raise asyncio.CancelledError from None + if self.real.max_runtime is not None and self.real.max_runtime > 0: + cancel_task = asyncio.create_task(self._max_runtime_task()) + while not self.returncode.done(): + await asyncio.sleep(0.01) + returncode = await self.returncode if ( returncode == 0 @@ -129,6 +109,8 @@ async def 
threadsafer_future_wait(future: asyncio.Future[_T]) -> _T: await self.aborted.wait() await self._send(State.ABORTED) finally: + if cancel_task and not cancel_task.done(): + cancel_task.cancel() sem.release() async def __call__( @@ -151,6 +133,20 @@ async def __call__( ) logger.error(message) + async def _max_runtime_task(self) -> None: + assert self.real.max_runtime is not None + await asyncio.sleep(self.real.max_runtime) + timeout_event = CloudEvent( + { + "type": EVTYPE_REALIZATION_TIMEOUT, + "source": f"/ert/ensemble/{self._scheduler._ens_id}/real/{self.iens}", + "id": str(uuid.uuid1()), + } + ) + await self._scheduler._events.put(to_json(timeout_event)) + + self.returncode.cancel() # Triggers CancelledError + async def _send(self, state: State) -> None: status = STATE_TO_LEGACY[state] event = CloudEvent(