Establish connection and empty the event queue before cancelling tasks
- Add _publisher_done event and CLOSE_PUBLISHER_SENTINEL to ensure that the connection is established and all events have been sent before cancellation happens (a minimal sketch of this pattern follows the file summary below).
- Suppress CancelledError when the task is cancelled for long-running jobs
- Ignore cancellation in the job task
- Add test for the scheduler publishing its events to a websocket, asserting that the _publisher_done Event is set.

Co-authored-by: Håvard Berland <[email protected]>
xjules and berland committed Apr 8, 2024
1 parent f7d8650 commit d3b2544
Showing 6 changed files with 58 additions and 5 deletions.
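
At its core the change is a small asyncio synchronization pattern: the shutdown path puts CLOSE_PUBLISHER_SENTINEL on the scheduler's internal event queue and waits on _publisher_done before cancelling job tasks, so every queued event reaches the websocket first. The following is a minimal, standalone sketch of that pattern, not ert's actual Scheduler: the publisher/main helpers, the printed stand-in for conn.send, and the example event names are illustrative assumptions; only the sentinel-plus-Event idea comes from this commit.

import asyncio
import contextlib

# Illustrative stand-in; ert imports CLOSE_PUBLISHER_SENTINEL from ert.job_queue.queue.
CLOSE_PUBLISHER_SENTINEL = object()


async def publisher(events: asyncio.Queue, publisher_done: asyncio.Event) -> None:
    # Drain the queue until the sentinel arrives, then signal that publishing is done.
    while True:
        event = await events.get()
        if event is CLOSE_PUBLISHER_SENTINEL:
            publisher_done.set()
            return
        print(f"sent {event}")  # stand-in for `await conn.send(event)`
        events.task_done()


async def main() -> None:
    events: asyncio.Queue = asyncio.Queue()
    publisher_done = asyncio.Event()
    publisher_task = asyncio.create_task(publisher(events, publisher_done))
    job_task = asyncio.create_task(asyncio.sleep(3600))  # a long-running job

    for event in ("SUBMITTED", "PENDING", "RUNNING", "SUCCESS"):
        await events.put(event)

    # Shutdown: flush the queue and wait for the publisher before cancelling anything.
    await events.put(CLOSE_PUBLISHER_SENTINEL)
    await publisher_done.wait()

    job_task.cancel()
    with contextlib.suppress(asyncio.CancelledError):  # mirrors "Suppress CancelledError"
        await job_task


asyncio.run(main())

In the commit itself the same wait happens in _monitor_and_handle_tasks, guarded by `if self._ee_uri is not None`, since there is nothing to flush when no websocket is configured.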
2 changes: 1 addition & 1 deletion src/ert/scheduler/job.py
@@ -145,7 +145,7 @@ async def __call__(
                 )
                 logger.warning(message)
                 self.returncode = asyncio.Future()
-                self.started = asyncio.Event()
+                self.started.clear()
             else:
                 await self._send(State.FAILED)
14 changes: 13 additions & 1 deletion src/ert/scheduler/scheduler.py
@@ -26,7 +26,11 @@
 from websockets.client import connect

 from ert.constant_filenames import CERT_FILE
-from ert.job_queue.queue import EVTYPE_ENSEMBLE_CANCELLED, EVTYPE_ENSEMBLE_STOPPED
+from ert.job_queue.queue import (
+    CLOSE_PUBLISHER_SENTINEL,
+    EVTYPE_ENSEMBLE_CANCELLED,
+    EVTYPE_ENSEMBLE_STOPPED,
+)
 from ert.scheduler.driver import SIGNAL_OFFSET, Driver
 from ert.scheduler.event import FinishedEvent
 from ert.scheduler.job import Job
@@ -106,6 +110,7 @@ def __init__(
         self._ens_id = ens_id
         self._ee_cert = ee_cert
         self._ee_token = ee_token
+        self._publisher_done = asyncio.Event()

     def kill_all_jobs(self) -> None:
         assert self._loop
@@ -187,7 +192,11 @@ async def _publisher(self) -> None:
         ):
             while True:
                 event = await self._events.get()
+                if event == CLOSE_PUBLISHER_SENTINEL:
+                    self._publisher_done.set()
+                    return
                 await conn.send(event)
+                self._events.task_done()

     def add_dispatch_information_to_jobs_file(self) -> None:
         for job in self._jobs.values():
@@ -225,6 +234,9 @@ async def _monitor_and_handle_tasks(
                     raise task_exception

         if not self.is_active():
+            if self._ee_uri is not None:
+                await self._events.put(CLOSE_PUBLISHER_SENTINEL)
+                await self._publisher_done.wait()
             for task in self._job_tasks.values():
                 if task.cancelled():
                     continue
1 change: 0 additions & 1 deletion tests/integration_tests/analysis/test_es_update.py
@@ -164,7 +164,6 @@ def sample_prior(nx, ny):
     )


-@pytest.mark.skip(reason="Very flaky with scheduler")
 @pytest.mark.integration_test
 @pytest.mark.usefixtures("copy_snake_oil_field", "using_scheduler")
 def test_update_multiple_param():
1 change: 0 additions & 1 deletion tests/integration_tests/scheduler/test_openpbs_driver.py
@@ -41,7 +41,6 @@ def queue_name_config():
 def test_that_openpbs_driver_ignores_qstat_flakiness(
     text_to_ignore, caplog, capsys, create_mock_flaky_qstat
 ):
-
     create_mock_flaky_qstat(text_to_ignore)
     with open("poly.ert", mode="a+", encoding="utf-8") as f:
         f.write("QUEUE_SYSTEM TORQUE\nNUM_REALIZATIONS 1")
3 changes: 2 additions & 1 deletion tests/unit_tests/scheduler/test_job.py
@@ -67,8 +67,9 @@ async def test_submitted_job_is_cancelled(realization, mock_event):

     await asyncio.wait_for(job.started._mock_waited, 5)

-    assert job_task.cancel()
+    job_task.cancel()
+    await job_task

     await assert_scheduler_events(
         scheduler, [State.SUBMITTING, State.PENDING, State.ABORTING, State.ABORTED]
     )
42 changes: 42 additions & 0 deletions tests/unit_tests/scheduler/test_scheduler.py
@@ -8,6 +8,7 @@
 from typing import List

 import pytest
+import websockets
 from cloudevents.http import from_json

 from ert.config import QueueConfig
@@ -476,6 +477,47 @@ async def mock_failure(message, *args, **kwargs):
         raise RuntimeError(message)


+async def _mock_ws(set_when_done: asyncio.Event, handler, port: int):
+    async with websockets.server.serve(handler, "127.0.0.1", port):
+        await set_when_done.wait()
+
+
+async def test_scheduler_publishes_to_websocket(
+    mock_driver, realization, unused_tcp_port
+):
+    set_when_done = asyncio.Event()
+
+    events_received: List[str] = []
+
+    async def mock_ws_event_handler(websocket):
+        nonlocal events_received
+        async for message in websocket:
+            events_received.append(message)
+        await websocket.close()
+
+    websocket_server_task = asyncio.create_task(
+        _mock_ws(set_when_done, mock_ws_event_handler, unused_tcp_port)
+    )
+
+    driver = mock_driver()
+    sch = scheduler.Scheduler(
+        driver, [realization], ee_uri=f"ws://127.0.0.1:{unused_tcp_port}"
+    )
+    await sch.execute()
+    # publisher_done is set only if CLOSE_PUBLISHER_SENTINEL was received
+    assert sch._publisher_done.is_set()
+
+    set_when_done.set()
+    await websocket_server_task
+    assert [
+        json.loads(event)["data"]["queue_event_type"] for event in events_received
+    ] == ["SUBMITTED", "PENDING", "RUNNING", "SUCCESS"]
+
+    assert (
+        sch._events.empty()
+    ), "Schedulers internal event queue must be empty before finish"
+
+
 @pytest.mark.timeout(5)
 async def test_that_driver_poll_exceptions_are_propagated(mock_driver, realization):
     driver = mock_driver()
