Skip to content

Commit

Permalink
Add publisher_connected event to make sure that the connection was es…
Browse files Browse the repository at this point in the history
…tahblished before cancellation happens
  • Loading branch information
xjules committed Apr 5, 2024
1 parent b4be6f4 commit a943506
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 5 deletions.
8 changes: 8 additions & 0 deletions src/ert/scheduler/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def __init__(
self._ens_id = ens_id
self._ee_cert = ee_cert
self._ee_token = ee_token
self._publisher_connected = asyncio.Event()

def kill_all_jobs(self) -> None:
assert self._loop
Expand Down Expand Up @@ -184,15 +185,19 @@ async def _publisher(self) -> None:
ping_interval=60,
close_timeout=60,
):
self._publisher_connected.set()
try:
while True:
event = await self._events.get()
await conn.send(event)
self._events.task_done()
except asyncio.CancelledError:
while not self._events.empty():
event = await self._events.get_nowait()
await conn.send(event)
raise
finally:
self._publisher_connected.clear()

def add_dispatch_information_to_jobs_file(self) -> None:
for job in self._jobs.values():
Expand Down Expand Up @@ -228,7 +233,10 @@ async def _monitor_and_handle_tasks(
if task in scheduling_tasks:
await self._cancel_job_tasks()
raise task_exception

if not self.is_active():
if self._ee_uri is not None:
await self._publisher_connected.wait()
for task in self._job_tasks.values():
if task.cancelled():
continue
Expand Down
9 changes: 4 additions & 5 deletions tests/unit_tests/scheduler/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,8 +479,7 @@ async def mock_failure(message, *args, **kwargs):

async def _mock_ws(set_when_done: asyncio.Event, handler, port: int):
async with websockets.server.serve(handler, "127.0.0.1", port):
while not set_when_done.is_set():
await asyncio.sleep(0)
await set_when_done.wait()


async def test_scheduler_publishes_to_websocket(
Expand All @@ -494,6 +493,7 @@ async def mock_ws_event_handler(websocket):
nonlocal events_received
async for message in websocket:
events_received.append(message)
await websocket.close()

websocket_server_task = asyncio.create_task(
_mock_ws(set_when_done, mock_ws_event_handler, unused_tcp_port)
Expand All @@ -505,6 +505,8 @@ async def mock_ws_event_handler(websocket):
)
await sch.execute()

set_when_done.set()
await websocket_server_task
assert [
json.loads(event)["data"]["queue_event_type"] for event in events_received
] == ["SUBMITTED", "PENDING", "RUNNING", "SUCCESS"]
Expand All @@ -513,9 +515,6 @@ async def mock_ws_event_handler(websocket):
sch._events.empty()
), "Schedulers internal queue of events to be sent must be empty before it can finish"

set_when_done.set()
await websocket_server_task


@pytest.mark.timeout(5)
async def test_that_driver_poll_exceptions_are_propagated(mock_driver, realization):
Expand Down

0 comments on commit a943506

Please sign in to comment.