Skip to content

Commit

Permalink
Alternative with asyncio.wait_for()
Browse files: browse the repository at this point in the history
  • Loading branch information
berland committed Dec 15, 2023
1 parent 89fa980 commit 248d823
Showing 1 changed file with 27 additions and 23 deletions.
50 changes: 27 additions & 23 deletions src/ert/scheduler/job.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging
import uuid
from enum import Enum
from typing import TYPE_CHECKING, Callable, Optional
from typing import TYPE_CHECKING, Callable, Optional, TypeVar

from cloudevents.conversion import to_json
from cloudevents.http import CloudEvent
Expand All @@ -20,6 +20,7 @@
from ert.scheduler.scheduler import Scheduler

logger = logging.getLogger(__name__)
_T = TypeVar("_T")


class State(str, Enum):
Expand Down Expand Up @@ -74,7 +75,6 @@ def driver(self) -> Driver:

async def _submit_and_run_once(self, sem: asyncio.BoundedSemaphore) -> None:
await sem.acquire()
cancel_task: Optional[asyncio.Task[None]] = None

try:
await self._send(State.SUBMITTING)
Expand All @@ -86,11 +86,31 @@ async def _submit_and_run_once(self, sem: asyncio.BoundedSemaphore) -> None:
await self.started.wait()

await self._send(State.RUNNING)
if self.real.max_runtime is not None and self.real.max_runtime > 0:
cancel_task = asyncio.create_task(self._max_runtime_task())
while not self.returncode.done():
await asyncio.sleep(0.01)
returncode = await self.returncode

async def threadsafer_future_wait(future: asyncio.Future[_T]) -> _T:
    """Poll *future* until it is done, then return its outcome.

    Rather than awaiting the future directly, this checks
    ``future.done()`` and sleeps 10 ms between checks. Once the future
    is done, awaiting it returns its result or re-raises its exception
    or cancellation.

    NOTE(review): the name suggests the polling exists because the
    future may be completed from another thread -- confirm.
    """
    poll_interval = 0.01  # seconds between done() checks
    while True:
        if future.done():
            break
        await asyncio.sleep(poll_interval)
    return await future

try:
returncode = await asyncio.wait_for(
threadsafer_future_wait(self.returncode),
timeout=self.real.max_runtime,
)
except asyncio.TimeoutError:
assert self.real.max_runtime is not None
await asyncio.sleep(self.real.max_runtime)
timeout_event = CloudEvent(
{
"type": EVTYPE_REALIZATION_TIMEOUT,
"source": f"/ert/ensemble/{self._scheduler._ens_id}/real/{self.iens}",
"id": str(uuid.uuid1()),
}
)
assert self._scheduler._events is not None
await self._scheduler._events.put(to_json(timeout_event))
raise asyncio.CancelledError from None

if (
returncode == 0
Expand All @@ -109,8 +129,6 @@ async def _submit_and_run_once(self, sem: asyncio.BoundedSemaphore) -> None:
await self.aborted.wait()
await self._send(State.ABORTED)
finally:
if cancel_task and not cancel_task.done():
cancel_task.cancel()
sem.release()

async def __call__(
Expand All @@ -133,20 +151,6 @@ async def __call__(
)
logger.error(message)

async def _max_runtime_task(self) -> None:
    """Wait out the realization's max runtime, then report a timeout.

    Sleeps for ``self.real.max_runtime`` seconds, puts a JSON-encoded
    ``EVTYPE_REALIZATION_TIMEOUT`` CloudEvent on the scheduler's
    ``_events`` and finally cancels ``self.returncode``.
    """
    assert self.real.max_runtime is not None
    await asyncio.sleep(self.real.max_runtime)

    # Identify the realization by ensemble id and realization index.
    event_attributes = {
        "type": EVTYPE_REALIZATION_TIMEOUT,
        "source": f"/ert/ensemble/{self._scheduler._ens_id}/real/{self.iens}",
        "id": str(uuid.uuid1()),
    }
    timeout_notice = CloudEvent(event_attributes)
    await self._scheduler._events.put(to_json(timeout_notice))

    # Cancelling the future triggers CancelledError for whoever awaits it.
    self.returncode.cancel()

async def _send(self, state: State) -> None:
status = STATE_TO_LEGACY[state]
event = CloudEvent(
Expand Down

0 comments on commit 248d823

Please sign in to comment.