Patch summary (free text ahead of the first file header; patch tools skip it).

  * src/ert/scheduler/job.py: in the branch that logs a warning and resets
    `returncode`, clear the existing `started` event instead of replacing it
    with a new asyncio.Event.
  * src/ert/scheduler/scheduler.py: import CLOSE_PUBLISHER_SENTINEL and add a
    `_publisher_done` event. `_publisher` sets that event and returns when it
    reads the sentinel from the event queue, and calls task_done() after each
    event it sends. When the scheduler is no longer active and `_ee_uri` is
    set, `_monitor_and_handle_tasks` enqueues the sentinel and waits for
    `_publisher_done` before inspecting the job tasks.
  * tests/integration_tests/analysis/test_es_update.py: drop the skip marker
    on test_update_multiple_param.
  * tests/integration_tests/scheduler/test_openpbs_driver.py: remove one
    blank line.
  * tests/unit_tests/scheduler/test_job.py: call job_task.cancel() without
    asserting its return value; add a blank line.
  * tests/unit_tests/scheduler/test_scheduler.py: add a websocket-server
    helper and test_scheduler_publishes_to_websocket.

NOTE(review): this patch was recovered from a copy whose line breaks had been
collapsed into spaces. Blank context lines were re-inserted so that every hunk
matches its line counts, and leading indentation was reconstructed from the
Python block structure. Confirm the whitespace against the repository, or
apply with `git apply --ignore-whitespace` and review the indentation of the
added lines by hand.

diff --git a/src/ert/scheduler/job.py b/src/ert/scheduler/job.py
index c90bae21bd6..37e97a3fb82 100644
--- a/src/ert/scheduler/job.py
+++ b/src/ert/scheduler/job.py
@@ -145,7 +145,7 @@ async def __call__(
                 )
                 logger.warning(message)
                 self.returncode = asyncio.Future()
-                self.started = asyncio.Event()
+                self.started.clear()
             else:
                 await self._send(State.FAILED)
 
diff --git a/src/ert/scheduler/scheduler.py b/src/ert/scheduler/scheduler.py
index 201837cadda..26c60749c85 100644
--- a/src/ert/scheduler/scheduler.py
+++ b/src/ert/scheduler/scheduler.py
@@ -26,7 +26,11 @@
 from websockets.client import connect
 
 from ert.constant_filenames import CERT_FILE
-from ert.job_queue.queue import EVTYPE_ENSEMBLE_CANCELLED, EVTYPE_ENSEMBLE_STOPPED
+from ert.job_queue.queue import (
+    CLOSE_PUBLISHER_SENTINEL,
+    EVTYPE_ENSEMBLE_CANCELLED,
+    EVTYPE_ENSEMBLE_STOPPED,
+)
 from ert.scheduler.driver import SIGNAL_OFFSET, Driver
 from ert.scheduler.event import FinishedEvent
 from ert.scheduler.job import Job
@@ -106,6 +110,7 @@ def __init__(
         self._ens_id = ens_id
         self._ee_cert = ee_cert
         self._ee_token = ee_token
+        self._publisher_done = asyncio.Event()
 
     def kill_all_jobs(self) -> None:
         assert self._loop
@@ -187,7 +192,11 @@ async def _publisher(self) -> None:
         ):
             while True:
                 event = await self._events.get()
+                if event == CLOSE_PUBLISHER_SENTINEL:
+                    self._publisher_done.set()
+                    return
                 await conn.send(event)
+                self._events.task_done()
 
     def add_dispatch_information_to_jobs_file(self) -> None:
         for job in self._jobs.values():
@@ -225,6 +234,9 @@ async def _monitor_and_handle_tasks(
                     raise task_exception
 
             if not self.is_active():
+                if self._ee_uri is not None:
+                    await self._events.put(CLOSE_PUBLISHER_SENTINEL)
+                    await self._publisher_done.wait()
                 for task in self._job_tasks.values():
                     if task.cancelled():
                         continue
diff --git a/tests/integration_tests/analysis/test_es_update.py b/tests/integration_tests/analysis/test_es_update.py
index 79452284a5c..39f4648cb27 100644
--- a/tests/integration_tests/analysis/test_es_update.py
+++ b/tests/integration_tests/analysis/test_es_update.py
@@ -164,7 +164,6 @@ def sample_prior(nx, ny):
     )
 
 
-@pytest.mark.skip(reason="Very flaky with scheduler")
 @pytest.mark.integration_test
 @pytest.mark.usefixtures("copy_snake_oil_field", "using_scheduler")
 def test_update_multiple_param():
diff --git a/tests/integration_tests/scheduler/test_openpbs_driver.py b/tests/integration_tests/scheduler/test_openpbs_driver.py
index 669b049c741..a4740a3ab0f 100644
--- a/tests/integration_tests/scheduler/test_openpbs_driver.py
+++ b/tests/integration_tests/scheduler/test_openpbs_driver.py
@@ -41,7 +41,6 @@ def queue_name_config():
 def test_that_openpbs_driver_ignores_qstat_flakiness(
     text_to_ignore, caplog, capsys, create_mock_flaky_qstat
 ):
-
     create_mock_flaky_qstat(text_to_ignore)
     with open("poly.ert", mode="a+", encoding="utf-8") as f:
         f.write("QUEUE_SYSTEM TORQUE\nNUM_REALIZATIONS 1")
diff --git a/tests/unit_tests/scheduler/test_job.py b/tests/unit_tests/scheduler/test_job.py
index 32ede9be8f4..2cc8d4b3fcd 100644
--- a/tests/unit_tests/scheduler/test_job.py
+++ b/tests/unit_tests/scheduler/test_job.py
@@ -67,8 +67,9 @@ async def test_submitted_job_is_cancelled(realization, mock_event):
 
     await asyncio.wait_for(job.started._mock_waited, 5)
 
-    assert job_task.cancel()
+    job_task.cancel()
     await job_task
+
     await assert_scheduler_events(
         scheduler, [State.SUBMITTING, State.PENDING, State.ABORTING, State.ABORTED]
     )
diff --git a/tests/unit_tests/scheduler/test_scheduler.py b/tests/unit_tests/scheduler/test_scheduler.py
index 33e2b8974e7..bf2463570ca 100644
--- a/tests/unit_tests/scheduler/test_scheduler.py
+++ b/tests/unit_tests/scheduler/test_scheduler.py
@@ -8,6 +8,7 @@
 from typing import List
 
 import pytest
+import websockets
 from cloudevents.http import from_json
 
 from ert.config import QueueConfig
@@ -476,6 +477,47 @@ async def mock_failure(message, *args, **kwargs):
     raise RuntimeError(message)
 
 
+async def _mock_ws(set_when_done: asyncio.Event, handler, port: int):
+    async with websockets.server.serve(handler, "127.0.0.1", port):
+        await set_when_done.wait()
+
+
+async def test_scheduler_publishes_to_websocket(
+    mock_driver, realization, unused_tcp_port
+):
+    set_when_done = asyncio.Event()
+
+    events_received: List[str] = []
+
+    async def mock_ws_event_handler(websocket):
+        nonlocal events_received
+        async for message in websocket:
+            events_received.append(message)
+        await websocket.close()
+
+    websocket_server_task = asyncio.create_task(
+        _mock_ws(set_when_done, mock_ws_event_handler, unused_tcp_port)
+    )
+
+    driver = mock_driver()
+    sch = scheduler.Scheduler(
+        driver, [realization], ee_uri=f"ws://127.0.0.1:{unused_tcp_port}"
+    )
+    await sch.execute()
+    # publisher_done is set only if CLOSE_PUBLISHER_SENTINEL was received
+    assert sch._publisher_done.is_set()
+
+    set_when_done.set()
+    await websocket_server_task
+    assert [
+        json.loads(event)["data"]["queue_event_type"] for event in events_received
+    ] == ["SUBMITTED", "PENDING", "RUNNING", "SUCCESS"]
+
+    assert (
+        sch._events.empty()
+    ), "Schedulers internal event queue must be empty before finish"
+
+
 @pytest.mark.timeout(5)
 async def test_that_driver_poll_exceptions_are_propagated(mock_driver, realization):
     driver = mock_driver()