From 248d823b0ee38fc73c4089a967bdb3645f0e89c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=A5vard=20Berland?= Date: Thu, 14 Dec 2023 13:19:30 +0100 Subject: [PATCH] Alternative with asyncio.wait_for() --- src/ert/scheduler/job.py | 50 ++++++++++++++++++++++------------------ 1 file changed, 27 insertions(+), 23 deletions(-) diff --git a/src/ert/scheduler/job.py b/src/ert/scheduler/job.py index d7c75f169af..174bff74f29 100644 --- a/src/ert/scheduler/job.py +++ b/src/ert/scheduler/job.py @@ -4,7 +4,7 @@ import logging import uuid from enum import Enum -from typing import TYPE_CHECKING, Callable, Optional +from typing import TYPE_CHECKING, Callable, Optional, TypeVar from cloudevents.conversion import to_json from cloudevents.http import CloudEvent @@ -20,6 +20,7 @@ from ert.scheduler.scheduler import Scheduler logger = logging.getLogger(__name__) +_T = TypeVar("_T") class State(str, Enum): @@ -74,7 +75,6 @@ def driver(self) -> Driver: async def _submit_and_run_once(self, sem: asyncio.BoundedSemaphore) -> None: await sem.acquire() - cancel_task: Optional[asyncio.Task[None]] = None try: await self._send(State.SUBMITTING) @@ -86,11 +86,31 @@ async def _submit_and_run_once(self, sem: asyncio.BoundedSemaphore) -> None: await self.started.wait() await self._send(State.RUNNING) - if self.real.max_runtime is not None and self.real.max_runtime > 0: - cancel_task = asyncio.create_task(self._max_runtime_task()) - while not self.returncode.done(): - await asyncio.sleep(0.01) - returncode = await self.returncode + + async def threadsafer_future_wait(future: asyncio.Future[_T]) -> _T: + while not future.done(): + await asyncio.sleep(0.01) + returncode = await future + return returncode + + try: + returncode = await asyncio.wait_for( + threadsafer_future_wait(self.returncode), + timeout=self.real.max_runtime, + ) + except asyncio.TimeoutError: + assert self.real.max_runtime is not None + await asyncio.sleep(self.real.max_runtime) + timeout_event = CloudEvent( + { + "type": EVTYPE_REALIZATION_TIMEOUT, + "source": f"/ert/ensemble/{self._scheduler._ens_id}/real/{self.iens}", + "id": str(uuid.uuid1()), + } + ) + assert self._scheduler._events is not None + await self._scheduler._events.put(to_json(timeout_event)) + raise asyncio.CancelledError from None if ( returncode == 0 @@ -109,8 +129,6 @@ async def _submit_and_run_once(self, sem: asyncio.BoundedSemaphore) -> None: await self.aborted.wait() await self._send(State.ABORTED) finally: - if cancel_task and not cancel_task.done(): - cancel_task.cancel() sem.release() async def __call__( @@ -133,20 +151,6 @@ async def __call__( ) logger.error(message) - async def _max_runtime_task(self) -> None: - assert self.real.max_runtime is not None - await asyncio.sleep(self.real.max_runtime) - timeout_event = CloudEvent( - { - "type": EVTYPE_REALIZATION_TIMEOUT, - "source": f"/ert/ensemble/{self._scheduler._ens_id}/real/{self.iens}", - "id": str(uuid.uuid1()), - } - ) - await self._scheduler._events.put(to_json(timeout_event)) - - self.returncode.cancel() # Triggers CancelledError - async def _send(self, state: State) -> None: status = STATE_TO_LEGACY[state] event = CloudEvent(