From f687bbad1a029a1be675885817ab29d618ebf9db Mon Sep 17 00:00:00 2001 From: xjules Date: Fri, 5 Jan 2024 16:27:28 +0100 Subject: [PATCH] Correctly cancel tasks from sync context in Scheduler When stopping the execution from ee, which runs in another thread, we need to use the correct loop when cancelling the job tasks. Further, we only signal the cancellation, so we don't need to await the tasks finishing; that is handled by asyncio.gather in Scheduler.execute. There are two functions (kill_all_jobs and cancel_all_jobs) for cancelling the tasks in the Scheduler; kill_all_jobs is meant to be used from a sync context. --- src/ert/scheduler/scheduler.py | 17 +++++++++++++++-- tests/unit_tests/scheduler/test_scheduler.py | 4 ++-- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/src/ert/scheduler/scheduler.py b/src/ert/scheduler/scheduler.py index 761f39043b2..e13d42b2b0a 100644 --- a/src/ert/scheduler/scheduler.py +++ b/src/ert/scheduler/scheduler.py @@ -62,6 +62,8 @@ def __init__( } self._events: asyncio.Queue[Any] = asyncio.Queue() + self._loop: Optional[asyncio.AbstractEventLoop] = None + self._average_job_runtime: float = 0 self._completed_jobs_num: int = 0 self.completed_jobs: asyncio.Queue[int] = asyncio.Queue() @@ -76,6 +78,10 @@ def __init__( self._ee_token = ee_token def kill_all_jobs(self) -> None: + assert self._loop + asyncio.run_coroutine_threadsafe(self.cancel_all_jobs(), self._loop) + + async def cancel_all_jobs(self) -> None: self._cancelled = True for task in self._tasks.values(): task.cancel() @@ -148,6 +154,9 @@ async def execute( self, min_required_realizations: int = 0, ) -> str: + # We need to store the loop due to when calling + # cancel jobs from another thread + self._loop = asyncio.get_running_loop() async with background_tasks() as cancel_when_execute_is_done: cancel_when_execute_is_done(self._publisher()) cancel_when_execute_is_done(self._process_event_queue()) @@ -166,8 +175,12 @@ async def execute( ) start.set() - for task in 
self._tasks.values(): - await task + results = await asyncio.gather( + *self._tasks.values(), return_exceptions=True + ) + for result in results: + if isinstance(result, Exception): + logger.error(result) await self.driver.finish() diff --git a/tests/unit_tests/scheduler/test_scheduler.py b/tests/unit_tests/scheduler/test_scheduler.py index 788c2a65fd5..d0ca81584f3 100644 --- a/tests/unit_tests/scheduler/test_scheduler.py +++ b/tests/unit_tests/scheduler/test_scheduler.py @@ -107,7 +107,7 @@ async def kill(): await asyncio.wait_for(pre.wait(), timeout=1) # Kill all jobs and wait for the scheduler to complete - sch.kill_all_jobs() + await sch.cancel_all_jobs() # this is equivalent to sch.kill_all_jobs() await scheduler_task assert pre.is_set() @@ -272,7 +272,7 @@ async def kill(): scheduler_task = asyncio.create_task(sch.execute()) await now_kill_me.wait() - sch.kill_all_jobs() + await sch.cancel_all_jobs() # this is equivalent to sch.kill_all_jobs() # Sleep until max_runtime must have kicked in: await asyncio.sleep(1.1)