Skip to content

Commit

Permalink
Cancel correctly tasks from sync context in Scheduler
Browse files Browse the repository at this point in the history
When stopping the executing from ee, which runs in another thread, we need to use the correct loop
when cancelling the job tasks. Further, we just signal to cancel therefore we don't need to await
for the tasks to finish. This is handled in the Scheduler.execute - asyncio.gather.

There two functions (kill_all_jobs and cancel_all_jobs) to cancel the tasks in the Scheduler. kill_all_jobs is meant to be used from sync context.
  • Loading branch information
xjules committed Jan 12, 2024
1 parent dd3a01d commit f687bba
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 4 deletions.
17 changes: 15 additions & 2 deletions src/ert/scheduler/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ def __init__(
}

self._events: asyncio.Queue[Any] = asyncio.Queue()
self._loop: Optional[asyncio.AbstractEventLoop] = None

self._average_job_runtime: float = 0
self._completed_jobs_num: int = 0
self.completed_jobs: asyncio.Queue[int] = asyncio.Queue()
Expand All @@ -76,6 +78,10 @@ def __init__(
self._ee_token = ee_token

def kill_all_jobs(self) -> None:
assert self._loop
asyncio.run_coroutine_threadsafe(self.cancel_all_jobs(), self._loop)

async def cancel_all_jobs(self) -> None:
self._cancelled = True
for task in self._tasks.values():
task.cancel()
Expand Down Expand Up @@ -148,6 +154,9 @@ async def execute(
self,
min_required_realizations: int = 0,
) -> str:
# We need to store the loop due to when calling
# cancel jobs from another thread
self._loop = asyncio.get_running_loop()
async with background_tasks() as cancel_when_execute_is_done:
cancel_when_execute_is_done(self._publisher())
cancel_when_execute_is_done(self._process_event_queue())
Expand All @@ -166,8 +175,12 @@ async def execute(
)

start.set()
for task in self._tasks.values():
await task
results = await asyncio.gather(
*self._tasks.values(), return_exceptions=True
)
for result in results:
if isinstance(result, Exception):
logger.error(result)

await self.driver.finish()

Expand Down
4 changes: 2 additions & 2 deletions tests/unit_tests/scheduler/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ async def kill():
await asyncio.wait_for(pre.wait(), timeout=1)

# Kill all jobs and wait for the scheduler to complete
sch.kill_all_jobs()
await sch.cancel_all_jobs() # this is equivalent to sch.kill_all_jobs()
await scheduler_task

assert pre.is_set()
Expand Down Expand Up @@ -272,7 +272,7 @@ async def kill():
scheduler_task = asyncio.create_task(sch.execute())

await now_kill_me.wait()
sch.kill_all_jobs()
await sch.cancel_all_jobs() # this is equivalent to sch.kill_all_jobs()

# Sleep until max_runtime must have kicked in:
await asyncio.sleep(1.1)
Expand Down

0 comments on commit f687bba

Please sign in to comment.