Skip to content

Commit

Permalink
Delay initialization of asyncio.Queue
Browse files Browse the repository at this point in the history
This is needed for Python 3.8 because the initialization is using
the current event loop, for which we have multiple given multiple
threads in Ert. Thus we need to initialize as late as possible
  • Loading branch information
berland committed Dec 11, 2023
1 parent db623ce commit 489b68f
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 6 deletions.
9 changes: 5 additions & 4 deletions src/ert/scheduler/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@ class JobEvent(Enum):
class Driver(ABC):
"""Adapter for the HPC cluster."""

event_queue: asyncio.Queue[Tuple[int, JobEvent]]

def __init__(self) -> None:
self.event_queue = asyncio.Queue()
self.event_queue: Optional[asyncio.Queue[Tuple[int, JobEvent]]] = None

async def ainit(self) -> None:
if self.event_queue is None:
self.event_queue = asyncio.Queue()

@abstractmethod
async def submit(self, iens: int, executable: str, /, *args: str, cwd: str) -> None:
Expand All @@ -49,5 +51,4 @@ def create_poll_task(self) -> Optional[asyncio.Task[None]]:
Returns:
`asyncio.Task`, or None if polling is not applicable (eg. for LocalDriver)
"""

return None
3 changes: 3 additions & 0 deletions src/ert/scheduler/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,4 +109,7 @@ async def _send(self, state: State) -> None:
"queue_event_type": status,
},
)
if self._scheduler._events is None:
await self._scheduler.ainit()
assert self._scheduler._events is not None
await self._scheduler._events.put(to_json(event))
5 changes: 5 additions & 0 deletions src/ert/scheduler/local_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ async def _wait_until_finish(
cwd=cwd,
preexec_fn=os.setpgrp,
)

if self.event_queue is None:
await self.ainit()
assert self.event_queue is not None

await self.event_queue.put((iens, JobEvent.STARTED))
try:
if await proc.wait() == 0:
Expand Down
17 changes: 15 additions & 2 deletions src/ert/scheduler/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import os
import ssl
import threading
from asyncio.queues import Queue
from dataclasses import asdict
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -50,13 +49,19 @@ def __init__(self, driver: Optional[Driver] = None) -> None:
self.driver = driver
self._jobs: MutableMapping[int, Job] = {}
self._tasks: MutableMapping[int, asyncio.Task[None]] = {}
self._events: Queue[Any] = Queue()

self._events: Optional[asyncio.Queue[Any]] = None

self._ee_uri = ""
self._ens_id = ""
self._ee_cert: Optional[str] = None
self._ee_token: Optional[str] = None

async def ainit(self) -> None:
# While supporting Python 3.8, this statement must be delayed.
if self._events is None:
self._events = asyncio.Queue()

def add_realization(
self, real: Realization, callback_timeout: Callable[[int], None]
) -> None:
Expand Down Expand Up @@ -86,6 +91,10 @@ async def _publisher(self) -> None:
if self._ee_token:
headers["token"] = self._ee_token

if self._events is None:
await self.ainit()
assert self._events is not None

async with connect(
self._ee_uri,
ssl=tls,
Expand Down Expand Up @@ -132,6 +141,10 @@ async def execute(
return EVTYPE_ENSEMBLE_STOPPED

async def _process_event_queue(self) -> None:
if self.driver.event_queue is None:
await self.driver.ainit()
assert self.driver.event_queue is not None

while True:
iens, event = await self.driver.event_queue.get()
if event == JobEvent.STARTED:
Expand Down

0 comments on commit 489b68f

Please sign in to comment.