Skip to content

Commit

Permalink
Add unit tests for scheduler job states (#6824)
Browse files Browse the repository at this point in the history
* Add unit tests for scheduler job
  • Loading branch information
jonathan-eq authored Jan 5, 2024
1 parent a79a20b commit d509685
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 0 deletions.
22 changes: 22 additions & 0 deletions tests/unit_tests/scheduler/conftest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
import asyncio
from collections.abc import Generator
from typing import Any, Coroutine, Literal

import pytest

from ert.scheduler.local_driver import LocalDriver
Expand Down Expand Up @@ -35,3 +39,21 @@ async def _kill(self, iens, *args):
@pytest.fixture
def mock_driver():
return MockDriver


class MockEvent(asyncio.Event):
def __init__(self):
self._mock_waited = asyncio.Future()
super().__init__()

async def wait(self) -> Coroutine[Any, Any, Literal[True]]:
self._mock_waited.set_result(True)
return await super().wait()

def done(self) -> bool:
return True


@pytest.fixture
def mock_event():
return MockEvent
116 changes: 116 additions & 0 deletions tests/unit_tests/scheduler/test_job.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
import asyncio
import json
import shutil
from typing import List
from unittest.mock import AsyncMock, MagicMock

import pytest

import ert
from ert.ensemble_evaluator._builder._realization import Realization
from ert.load_status import LoadResult, LoadStatus
from ert.run_arg import RunArg
from ert.scheduler import Scheduler
from ert.scheduler.driver import JobEvent
from ert.scheduler.job import STATE_TO_LEGACY, Job, State


def create_scheduler():
sch = AsyncMock()
sch._events = asyncio.Queue()
sch.driver = AsyncMock()
return sch


@pytest.fixture
def realization():
run_arg = RunArg(
run_id="",
ensemble_storage=MagicMock(),
iens=0,
itr=0,
runpath="test_runpath",
job_name="test_job",
)
realization = Realization(
iens=run_arg.iens,
forward_models=[],
active=True,
max_runtime=None,
run_arg=run_arg,
num_cpu=1,
job_script=str(shutil.which("job_dispatch.py")),
)
return realization


async def assert_scheduler_events(
scheduler: Scheduler, job_events: List[JobEvent]
) -> None:
for job_event in job_events:
queue_event = await scheduler._events.get()
output = json.loads(queue_event.decode("utf-8"))
event = output.get("data").get("queue_event_type")
assert event == STATE_TO_LEGACY[job_event.value]
# should be no more events
assert scheduler._events.empty()


@pytest.mark.asyncio
async def test_submitted_job_is_cancelled(realization, mock_event):
scheduler = create_scheduler()
job = Job(scheduler, realization)
job._requested_max_submit = 1
job.started = mock_event()
job.aborted.set()
job_task = asyncio.create_task(job._submit_and_run_once(asyncio.Semaphore()))

await asyncio.wait_for(job.started._mock_waited, 5)

assert job_task.cancel()
await job_task
await assert_scheduler_events(
scheduler, [State.SUBMITTING, State.PENDING, State.ABORTING, State.ABORTED]
)
scheduler.driver.kill.assert_called_with(job.iens)
scheduler.driver.kill.assert_called_once()


@pytest.mark.parametrize(
"return_code, forward_model_ok_result, expected_final_event",
[
[0, LoadStatus.LOAD_SUCCESSFUL, State.COMPLETED],
[1, LoadStatus.LOAD_SUCCESSFUL, State.FAILED],
[0, LoadStatus.LOAD_FAILURE, State.FAILED],
[1, LoadStatus.LOAD_FAILURE, State.FAILED],
],
)
@pytest.mark.asyncio
async def test_job_submit_and_run_once(
return_code: int,
forward_model_ok_result,
expected_final_event: str,
realization: Realization,
monkeypatch,
):
monkeypatch.setattr(
ert.scheduler.job,
"forward_model_ok",
lambda _: LoadResult(forward_model_ok_result, ""),
)
scheduler = create_scheduler()
job = Job(scheduler, realization)
job._requested_max_submit = 1
job.started.set()
job.returncode.set_result(return_code)

await job._submit_and_run_once(asyncio.Semaphore())

await assert_scheduler_events(
scheduler,
[State.SUBMITTING, State.PENDING, State.RUNNING, expected_final_event],
)
scheduler.driver.submit.assert_called_with(
realization.iens, realization.job_script, cwd=realization.run_arg.runpath
)
scheduler.driver.submit.assert_called_once()

0 comments on commit d509685

Please sign in to comment.