From 4590a62230b78ab12f6e6a12dbbb5b28c69d39fe Mon Sep 17 00:00:00 2001 From: Jonathan Karlsen Date: Fri, 15 Dec 2023 16:04:45 +0100 Subject: [PATCH] Add unit tests for scheduler job --- tests/unit_tests/scheduler/test_job.py | 226 +++++++++++++++++++++++++ 1 file changed, 226 insertions(+) create mode 100644 tests/unit_tests/scheduler/test_job.py diff --git a/tests/unit_tests/scheduler/test_job.py b/tests/unit_tests/scheduler/test_job.py new file mode 100644 index 00000000000..db12b190603 --- /dev/null +++ b/tests/unit_tests/scheduler/test_job.py @@ -0,0 +1,226 @@ + +import pytest +import shutil +import asyncio +import json + +import ert +from ert.scheduler.driver import JobEvent +from ert.load_status import LoadStatus, LoadResult +from ert.ensemble_evaluator._builder._realization import Realization +from ert.run_arg import RunArg +from ert.scheduler import Scheduler +from ert.scheduler.job import Job,STATE_TO_LEGACY, State +from unittest.mock import MagicMock, AsyncMock + +@pytest.fixture +def scheduler(): + sch = MagicMock() + sch._events = asyncio.Queue() + sch.driver = AsyncMock() + return sch + +@pytest.fixture +def realization(): + run_arg = RunArg( + run_id="", + ensemble_storage=MagicMock(), + iens=0, + itr=0, + runpath="test_runpath", + job_name="test_job" + ) + realization = Realization( + iens=run_arg.iens, + forward_models=[], + active = True, + max_runtime=None, + run_arg=run_arg, + num_cpu=1, + job_script=str(shutil.which("job_dispatch.py")) + ) + return realization + + +async def assert_scheduler_events(scheduler: Scheduler, job_events: list[JobEvent]) -> None: + await asyncio.sleep(0) + assert scheduler._events.qsize() == len(job_events) + for job_event in job_events: + output = json.loads((await scheduler._events.get()).decode("utf-8")) + event = output.get("data").get("queue_event_type") + assert event == STATE_TO_LEGACY[job_event.value] + # should be no more events + assert scheduler._events.empty() + await asyncio.sleep(0) + + +@pytest.mark.parametrize("return_code, forward_model_ok_result, expected_final_event", [ + [0, LoadStatus.LOAD_SUCCESSFUL, State.COMPLETED], + [1, LoadStatus.LOAD_SUCCESSFUL, State.FAILED], + [0, LoadStatus.LOAD_FAILURE, State.FAILED], + [1, LoadStatus.LOAD_FAILURE, State.FAILED] +]) +@pytest.mark.asyncio +async def test_job_call(return_code: int,forward_model_ok_result, expected_final_event: str, scheduler: Scheduler, realization: Realization, monkeypatch): + load_result = LoadResult(forward_model_ok_result, "") + monkeypatch.setattr( + ert.scheduler.job, + "forward_model_ok", + lambda _: load_result + ) + + job = Job(scheduler, realization) + + start_event = asyncio.Event() + semaphore = asyncio.BoundedSemaphore(1) + + asyncio.create_task(job(start_event, semaphore)) + start_event.set() + + # should not be running before semaphore is available + assert scheduler._events.empty() + + # should now start submitting + await asyncio.sleep(0) + # SHOULD SPLIT THIS TO MAKE SURE State.Starting is ran AFTER scheduler.driver.submit() + await assert_scheduler_events(scheduler, [State.SUBMITTING, State.STARTING]) + scheduler.driver.submit.assert_called_with(realization.iens, realization.job_script, cwd=realization.run_arg.runpath) + scheduler.driver.submit.assert_called_once() + # should not run before the started event is set + assert scheduler._events.empty() + assert not job.started.is_set() + + # set started event + job.started.set() + await assert_scheduler_events(scheduler, [State.RUNNING]) + job.returncode.set_result(return_code) + await asyncio.sleep(0) + await assert_scheduler_events(scheduler, [expected_final_event]) + + # should release semaphore regardless of ending + assert not semaphore.locked() + +# test job failing +# test job aborting +# test job running + +@pytest.mark.asyncio +async def test_job_acquires_semaphore(scheduler, realization): + job = Job(scheduler, realization) + sem = asyncio.Semaphore(1) + start_event = asyncio.Event() + asyncio.create_task(job(start_event, sem)) + await asyncio.sleep(0) + start_event.set() + await asyncio.sleep(0) + assert sem.locked() + + +@pytest.mark.asyncio +async def test_job_waits_for_start_event(scheduler, realization): + job = Job(scheduler, realization) + sem = asyncio.Semaphore(1) + start_event = asyncio.Event() + asyncio.create_task(job(start_event, sem)) + await asyncio.sleep(0) + await asyncio.sleep(0) + await asyncio.sleep(0) + + assert not start_event.is_set() + assert not sem.locked() + await assert_scheduler_events(scheduler, []) + scheduler.driver.submit.assert_not_called() + + +@pytest.mark.asyncio +async def test_job_waits_for_semaphore(scheduler, realization): + job = Job(scheduler, realization) + sem = asyncio.Semaphore(1) + await sem.acquire() + + start_event = asyncio.Event() + asyncio.create_task(job(start_event, sem)) + start_event.set() + await asyncio.sleep(0) + await asyncio.sleep(0) + await asyncio.sleep(0) + assert sem.locked() + await assert_scheduler_events(scheduler, []) + sem.release() + await asyncio.sleep(0) + await asyncio.sleep(0) + await assert_scheduler_events(scheduler, [State.SUBMITTING, State.STARTING]) + scheduler.driver.submit.assert_called_once() + + +@pytest.mark.asyncio +async def test_job_waits_for_started_event(scheduler, realization): + job = Job(scheduler, realization) + sem = asyncio.Semaphore(1) + + start_event = asyncio.Event() + asyncio.create_task(job(start_event, sem)) + start_event.set() + + await assert_scheduler_events(scheduler, [State.SUBMITTING, State.STARTING]) + + job.started.set() + await assert_scheduler_events(scheduler, [State.RUNNING]) + + +@pytest.mark.asyncio +async def test_job_waits_for_started_event(scheduler, realization): + job = Job(scheduler, realization) + sem = asyncio.Semaphore(1) + + start_event = asyncio.Event() + asyncio.create_task(job(start_event, sem)) + start_event.set() + + await assert_scheduler_events(scheduler, [State.SUBMITTING, State.STARTING]) + + job.started.set() + await assert_scheduler_events(scheduler, [State.RUNNING]) + + +@pytest.mark.asyncio +async def test_job_releases_semaphore_on_exception(scheduler, realization): + + scheduler.driver.submit = AsyncMock(side_effect=ZeroDivisionError) + job = Job(scheduler, realization) + start_event = asyncio.Event() + semaphore = asyncio.BoundedSemaphore(1) + + job_task = asyncio.create_task(job(start_event, semaphore)) + start_event.set() + # should now start submitting + await asyncio.sleep(0) + await asyncio.sleep(0) + with pytest.raises(ZeroDivisionError): + job_task.result() + assert not semaphore.locked() + + +@pytest.mark.asyncio +async def test_job_is_cancelled(scheduler: Scheduler, realization): + job = Job(scheduler, realization) + sem = asyncio.Semaphore(1) + + start_event = asyncio.Event() + job_task = asyncio.create_task(job(start_event, sem)) + start_event.set() + await asyncio.sleep(0) + assert not job_task.cancelled() + clear_queue(scheduler._events) + assert job_task.cancel() + await assert_scheduler_events(scheduler, [State.ABORTING]) + scheduler.driver.kill.assert_called_with(job.iens) + scheduler.driver.kill.assert_called_once() + job.aborted.set() + await assert_scheduler_events(scheduler, [State.ABORTED]) + + +def clear_queue(queue: asyncio.Queue) -> None: + print(f"Clearing {queue.qsize()} items") + while not queue.empty(): + queue.get_nowait() \ No newline at end of file