Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Typing src/ert/job_queue #6626

Merged
Merged 1 commit on Nov 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -78,5 +78,8 @@ ignore_missing_imports = True
[mypy-ruamel]
ignore_missing_imports = True

[mypy-statemachine]
ignore_missing_imports = True

[mypy-ert.callbacks]
ignore_errors = True
37 changes: 17 additions & 20 deletions src/ert/job_queue/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@

if TYPE_CHECKING:
from ert.config import QueueConfig
from ert.job_queue import QueueableRealization, RealizationState
from ert.job_queue import RealizationState


class Driver(ABC):
def __init__(
self,
options: Optional[List[Tuple[str, str]]] = None,
):
self._options = {}
self._options: Dict[str, str] = {}

if options:
for key, value in options:
Expand All @@ -32,15 +32,15 @@ def has_option(self, option_key: str) -> bool:
return option_key in self._options

@abstractmethod
async def submit(self, realization: "RealizationState"):
async def submit(self, realization: "RealizationState") -> None:
pass

@abstractmethod
async def poll_statuses(self):
async def poll_statuses(self) -> None:
pass

@abstractmethod
async def kill(self, realization: "RealizationState"):
async def kill(self, realization: "RealizationState") -> None:
pass

@classmethod
Expand All @@ -60,10 +60,10 @@ def __init__(self, queue_config: List[Tuple[str, str]]):
self._currently_polling = False

@property
def optionnames(self):
def optionnames(self) -> List[str]:
return []

async def submit(self, realization: "RealizationState"):
async def submit(self, realization: "RealizationState") -> None:
"""Submit and *actually (a)wait* for the process to finish."""
realization.accept()
try:
Expand Down Expand Up @@ -93,16 +93,16 @@ async def submit(self, realization: "RealizationState"):
realization.runfail()
# TODO: fetch stdout/stderr

async def poll_statuses(self):
async def poll_statuses(self) -> None:
pass

async def kill(self, realization: "RealizationState"):
async def kill(self, realization: "RealizationState") -> None:
self._processes[realization].kill()
realization.verify_kill()


class LSFDriver(Driver):
def __init__(self, queue_options):
def __init__(self, queue_options: Optional[List[Tuple[str, str]]]):
super().__init__(queue_options)

self._realstate_to_lsfid: Dict["RealizationState", str] = {}
Expand All @@ -113,15 +113,14 @@ def __init__(self, queue_options):

self._currently_polling = False

async def submit(self, realization: "RealizationState"):
submit_cmd = [
async def submit(self, realization: "RealizationState") -> None:
submit_cmd: List[str] = [
"bsub",
"-J",
f"poly_{realization.realization.run_arg.iens}",
realization.realization.job_script,
realization.realization.run_arg.runpath,
str(realization.realization.job_script),
str(realization.realization.run_arg.runpath),
]
assert shutil.which(submit_cmd[0]) # does not propagate back..
process = await asyncio.create_subprocess_exec(
*submit_cmd,
stdout=asyncio.subprocess.PIPE,
Expand All @@ -142,13 +141,11 @@ async def submit(self, realization: "RealizationState"):
print(f"Submitted job {realization} and got LSF JOBID {lsf_id}")
except Exception:
# We should probably retry the submission, bsub stdout seems flaky.
print(f"ERROR: Could not parse lsf id from: {output}")
print(f"ERROR: Could not parse lsf id from: {output!r}")

async def poll_statuses(self) -> None:
if self._currently_polling:
# Don't repeat if we are called too often.
# So easy in async..
return self._statuses
return
self._currently_polling = True

if not self._realstate_to_lsfid:
Expand Down Expand Up @@ -198,6 +195,6 @@ async def poll_statuses(self) -> None:

self._currently_polling = False

async def kill(self, realization):
async def kill(self, realization: "RealizationState") -> None:
print(f"would like to kill {realization}")
pass
16 changes: 10 additions & 6 deletions src/ert/job_queue/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,12 @@ def realization_state(self, iens: int) -> RealizationState:
def count_realization_state(self, state: RealizationState) -> int:
return sum(real.current_state == state for real in self._realizations)

async def run_done_callback(self, state: RealizationState):
async def run_done_callback(self, state: RealizationState) -> Optional[LoadStatus]:
callback_status, status_msg = forward_model_ok(state.realization.run_arg)
if callback_status == LoadStatus.LOAD_SUCCESSFUL:
state.validate()
# todo: implement me
return None

@property
def stopped(self) -> bool:
Expand Down Expand Up @@ -162,7 +164,9 @@ def queue_size(self) -> int:

def _add_realization(self, realization: QueueableRealization) -> int:
self._realizations.append(
RealizationState(self, realization, retries=self._queue_config.max_submit - 1)
RealizationState(
self, realization, retries=self._queue_config.max_submit - 1
)
)
return len(self._realizations) - 1

Expand Down Expand Up @@ -311,9 +315,10 @@ async def execute(
await self.driver.poll_statuses()

for real in self._realizations:
if real.realization.max_runtime is None:
continue
if (
real.realization.max_runtime != None
and real.current_state == RealizationState.RUNNING
real.current_state == RealizationState.RUNNING
and real.start_time
and datetime.datetime.now() - real.start_time
> datetime.timedelta(seconds=real.realization.max_runtime)
Expand Down Expand Up @@ -353,7 +358,6 @@ async def execute(

return EVTYPE_ENSEMBLE_STOPPED


def add_realization_from_run_arg(
self,
run_arg: "RunArg",
Expand Down Expand Up @@ -412,7 +416,7 @@ def stop_long_running_realizations(
sum(real.runtime for real in completed) / finished_realizations
)

for job in self.job_list:
for job in self.job_list: # type: ignore
if job.runtime > LONG_RUNNING_FACTOR * average_runtime:
job.stop()

Expand Down
21 changes: 7 additions & 14 deletions src/ert/job_queue/realization_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,8 @@ class QueueableRealization: # Aka "Job" or previously "JobQueueNode"
max_runtime: Optional[int] = None
callback_timeout: Optional[Callable[[int], None]] = None

def __hash__(self):
# Elevate iens up to two levels? Check if it can be removed from run_arg
return self.run_arg.iens

def __repr__(self):
return str(self.run_arg.iens)


class RealizationState(StateMachine):
class RealizationState(StateMachine): # type: ignore
NOT_ACTIVE = State("NOT ACTIVE")
WAITING = State("WAITING", initial=True)
SUBMITTED = State("SUBMITTED")
Expand Down Expand Up @@ -102,7 +95,7 @@ def __init__(

donotgohere = UNKNOWN.to(STATUS_FAILURE)

def on_enter_state(self, target, event):
def on_enter_state(self, target: RealizationState) -> None:
if self.jobqueue._changes_to_publish is None:
return
if target in (
Expand All @@ -116,21 +109,21 @@ def on_enter_state(self, target, event):
change = {self.realization.run_arg.iens: target.id}
asyncio.create_task(self.jobqueue._changes_to_publish.put(change))

def on_enter_SUBMITTED(self):
def on_enter_SUBMITTED(self) -> None:
asyncio.create_task(self.jobqueue.driver.submit(self))

def on_enter_RUNNING(self):
def on_enter_RUNNING(self) -> None:
self.start_time = datetime.datetime.now()

def on_enter_EXIT(self):
def on_enter_EXIT(self) -> None:
if self.retries_left > 0:
self.retry()
self.retries_left -= 1
else:
self.invalidate()

def on_enter_DONE(self):
def on_enter_DONE(self) -> None:
asyncio.create_task(self.jobqueue.run_done_callback(self))

def on_enter_DO_KILL(self):
def on_enter_DO_KILL(self) -> None:
asyncio.create_task(self.jobqueue.driver.kill(self))
15 changes: 9 additions & 6 deletions tests/unit_tests/job_queue/test_job_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,12 @@ async def test_max_submit(tmpdir, monkeypatch, failing_script, max_submit_num):
await job_queue.stop_jobs_async()
await asyncio.gather(execute_task)


@pytest.mark.asyncio
@pytest.mark.parametrize("max_submit_num", [1, 3])
async def test_that_kill_queue_disregards_max_submit(tmpdir, max_submit_num, monkeypatch, simple_script):
async def test_that_kill_queue_disregards_max_submit(
tmpdir, max_submit_num, monkeypatch, simple_script
):
monkeypatch.chdir(tmpdir)
job_queue = create_local_queue(simple_script, max_submit=max_submit_num)
await job_queue.stop_jobs_async()
Expand All @@ -140,7 +143,11 @@ async def test_that_kill_queue_disregards_max_submit(tmpdir, max_submit_num, mon
print(tmpdir)
for iens in range(job_queue.queue_size):
assert not Path(f"dummy_path_{iens}/STATUS").exists()
assert job_queue.count_realization_state(RealizationState.IS_KILLED) == job_queue.queue_size
assert (
job_queue.count_realization_state(RealizationState.IS_KILLED)
== job_queue.queue_size
)


@pytest.mark.asyncio
@pytest.mark.timeout(5)
Expand Down Expand Up @@ -231,9 +238,6 @@ def test_add_dispatch_info_cert_none(tmpdir, monkeypatch, simple_script):
assert not (runpath / cert_file).exists()





@pytest.mark.skip(reason="Needs reimplementation")
def test_stop_long_running():
"""
Expand Down Expand Up @@ -283,7 +287,6 @@ def test_stop_long_running():
assert queue.snapshot()[i] == str(JobStatus.RUNNING)



@pytest.mark.usefixtures("use_tmpdir", "mock_fm_ok")
@pytest.mark.skip(reason="Needs reimplementation")
def test_num_cpu_submitted_correctly_lsf(tmpdir, simple_script):
Expand Down
Loading