diff --git a/src/ert/analysis/_es_update.py b/src/ert/analysis/_es_update.py index 20738e3d797..6ee982ea2d6 100644 --- a/src/ert/analysis/_es_update.py +++ b/src/ert/analysis/_es_update.py @@ -683,16 +683,16 @@ def adaptive_localization_progress_callback( start = time.time() for param_batch_idx in batches: X_local = temp_storage[param_group.name][param_batch_idx, :] - temp_storage[param_group.name][param_batch_idx, :] = ( - smoother_adaptive_es.assimilate( - X=X_local, - Y=S, - D=D, - alpha=1.0, # The user is responsible for scaling observation covariance (esmda usage) - correlation_threshold=module.correlation_threshold, - cov_YY=cov_YY, - progress_callback=adaptive_localization_progress_callback, - ) + temp_storage[param_group.name][ + param_batch_idx, : + ] = smoother_adaptive_es.assimilate( + X=X_local, + Y=S, + D=D, + alpha=1.0, # The user is responsible for scaling observation covariance (esmda usage) + correlation_threshold=module.correlation_threshold, + cov_YY=cov_YY, + progress_callback=adaptive_localization_progress_callback, ) _logger.info( f"Adaptive Localization of {param_group} completed in {(time.time() - start) / 60} minutes" @@ -849,9 +849,9 @@ def analysis_IES( ) if active_parameter_indices := param_group.index_list: X = temp_storage[param_group.name][active_parameter_indices, :] - temp_storage[param_group.name][active_parameter_indices, :] = ( - X + X @ sies_smoother.W / np.sqrt(len(iens_active_index) - 1) - ) + temp_storage[param_group.name][ + active_parameter_indices, : + ] = X + X @ sies_smoother.W / np.sqrt(len(iens_active_index) - 1) else: X = temp_storage[param_group.name] temp_storage[param_group.name] = X + X @ sies_smoother.W / np.sqrt( diff --git a/src/ert/config/ert_config.py b/src/ert/config/ert_config.py index 65d57811542..0b0505d4b0c 100644 --- a/src/ert/config/ert_config.py +++ b/src/ert/config/ert_config.py @@ -435,10 +435,12 @@ def __init__(self, job): ) @overload - def substitute(self, string: str) -> str: ... + def substitute(self, string: str) -> str: + ... @overload - def substitute(self, string: None) -> None: ... + def substitute(self, string: None) -> None: + ... def substitute(self, string): if string is None: diff --git a/src/ert/config/parsing/observations_parser.py b/src/ert/config/parsing/observations_parser.py index 158eaa5c909..4b44f3a475a 100644 --- a/src/ert/config/parsing/observations_parser.py +++ b/src/ert/config/parsing/observations_parser.py @@ -107,7 +107,9 @@ def parse(filename: str) -> ConfContent: ) -def _parse_content(content: str, filename: str) -> List[ +def _parse_content( + content: str, filename: str +) -> List[ Union[ SimpleHistoryDeclaration, Tuple[ObservationType, FileContextToken, Dict[FileContextToken, Any]], diff --git a/src/ert/config/response_config.py b/src/ert/config/response_config.py index 97f27be0d0d..6ebf3660191 100644 --- a/src/ert/config/response_config.py +++ b/src/ert/config/response_config.py @@ -12,7 +12,8 @@ class ResponseConfig(ABC): name: str @abstractmethod - def read_from_file(self, run_path: str, iens: int) -> xr.Dataset: ... + def read_from_file(self, run_path: str, iens: int) -> xr.Dataset: + ... def to_dict(self) -> Dict[str, Any]: data = dataclasses.asdict(self, dict_factory=CustomDict) diff --git a/src/ert/ensemble_evaluator/_builder/_legacy.py b/src/ert/ensemble_evaluator/_builder/_legacy.py index 188a783b00f..0de1df1fcc0 100644 --- a/src/ert/ensemble_evaluator/_builder/_legacy.py +++ b/src/ert/ensemble_evaluator/_builder/_legacy.py @@ -43,7 +43,8 @@ class _KillAllJobs(Protocol): - def kill_all_jobs(self) -> None: ... + def kill_all_jobs(self) -> None: + ... class LegacyEnsemble(Ensemble): diff --git a/src/ert/ensemble_evaluator/identifiers.py b/src/ert/ensemble_evaluator/identifiers.py index 56f8480c7b0..80a04089d8d 100644 --- a/src/ert/ensemble_evaluator/identifiers.py +++ b/src/ert/ensemble_evaluator/identifiers.py @@ -1,3 +1,5 @@ +from typing import Literal + ACTIVE = "active" CURRENT_MEMORY_USAGE = "current_memory_usage" DATA = "data" @@ -30,6 +32,15 @@ EVTYPE_FORWARD_MODEL_SUCCESS = "com.equinor.ert.forward_model_job.success" EVTYPE_FORWARD_MODEL_FAILURE = "com.equinor.ert.forward_model_job.failure" +EvGroupRealizationType = Literal[ + "com.equinor.ert.realization.failure", + "com.equinor.ert.realization.pending", + "com.equinor.ert.realization.running", + "com.equinor.ert.realization.success", + "com.equinor.ert.realization.unknown", + "com.equinor.ert.realization.waiting", + "com.equinor.ert.realization.timeout", +] EVGROUP_REALIZATION = { EVTYPE_REALIZATION_FAILURE, diff --git a/src/ert/scheduler/event_sender.py b/src/ert/scheduler/event_sender.py new file mode 100644 index 00000000000..fa9e893c02e --- /dev/null +++ b/src/ert/scheduler/event_sender.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +import asyncio +import ssl +from typing import TYPE_CHECKING, Any, Mapping, Optional + +from cloudevents.conversion import to_json +from cloudevents.http import CloudEvent +from websockets import Headers, connect + +if TYPE_CHECKING: + from ert.ensemble_evaluator.identifiers import EvGroupRealizationType + + +class EventSender: + def __init__( + self, + ens_id: Optional[str], + ee_uri: Optional[str], + ee_cert: Optional[str], + ee_token: Optional[str], + ) -> None: + self.ens_id = ens_id + self.ee_uri = ee_uri + self.ee_cert = ee_cert + self.ee_token = ee_token + self.events: asyncio.Queue[CloudEvent] = asyncio.Queue() + + async def send( + self, + type: EvGroupRealizationType, + source: str, + attributes: Optional[Mapping[str, Any]] = None, + data: Optional[Mapping[str, Any]] = None, + ) -> None: + event = CloudEvent( + { + "type": type, + "source": f"/ert/ensemble/{self.ens_id}/{source}", + **(attributes or {}), + }, + data, + ) + await self.events.put(event) + + async def publisher(self) -> None: + if not self.ee_uri: + return + tls: Optional[ssl.SSLContext] = None + if self.ee_cert: + tls = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + tls.load_verify_locations(cadata=self.ee_cert) + headers = Headers() + if self.ee_token: + headers["token"] = self.ee_token + + async for conn in connect( + self.ee_uri, + ssl=tls, + extra_headers=headers, + open_timeout=60, + ping_timeout=60, + ping_interval=60, + close_timeout=60, + ): + while True: + event = await self.events.get() + await conn.send(to_json(event)) diff --git a/src/ert/scheduler/job.py b/src/ert/scheduler/job.py index ea69b91d10f..620ff606b3e 100644 --- a/src/ert/scheduler/job.py +++ b/src/ert/scheduler/job.py @@ -7,28 +7,24 @@ from contextlib import suppress from enum import Enum from pathlib import Path -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, Callable, Coroutine, List, Mapping, Optional -from cloudevents.conversion import to_json -from cloudevents.http import CloudEvent from lxml import etree from ert.callbacks import forward_model_ok from ert.constant_filenames import ERROR_file -from ert.job_queue.queue import _queue_state_event_type from ert.load_status import LoadStatus from ert.scheduler.driver import Driver +from ert.scheduler.event_sender import EventSender from ert.storage.realization_storage_state import RealizationStorageState if TYPE_CHECKING: from ert.ensemble_evaluator._builder._realization import Realization - from ert.scheduler.scheduler import Scheduler + from ert.ensemble_evaluator.identifiers import EvGroupRealizationType + from ert.scheduler.scheduler import SubmitSleeper logger = logging.getLogger(__name__) -# Duplicated to avoid circular imports -EVTYPE_REALIZATION_TIMEOUT = "com.equinor.ert.realization.timeout" - class State(str, Enum): WAITING = "WAITING" @@ -41,6 +37,17 @@ class State(str, Enum): ABORTED = "ABORTED" +STATE_TO_EE: Mapping[State, EvGroupRealizationType] = { + State.WAITING: "com.equinor.ert.realization.waiting", + State.SUBMITTING: "com.equinor.ert.realization.waiting", + State.PENDING: "com.equinor.ert.realization.pending", + State.RUNNING: "com.equinor.ert.realization.running", + State.COMPLETED: "com.equinor.ert.realization.success", + State.FAILED: "com.equinor.ert.realization.failure", + State.ABORTED: "com.equinor.ert.realization.failure", +} + + STATE_TO_LEGACY = { State.WAITING: "WAITING", State.SUBMITTING: "SUBMITTED", @@ -60,13 +67,18 @@ class Job: (LSF, PBS, SLURM, etc.) """ - def __init__(self, scheduler: Scheduler, real: Realization) -> None: + def __init__( + self, + real: Realization, + *, + on_complete: Optional[Callable[[int], Coroutine[None, None, None]]] = None, + ) -> None: self.real = real self.state = State.WAITING self.started = asyncio.Event() self.returncode: asyncio.Future[int] = asyncio.Future() - self._aborted = False - self._scheduler: Scheduler = scheduler + self.on_complete = on_complete + self._event_sender: Optional[EventSender] = None self._callback_status_msg: str = "" self._requested_max_submit: Optional[int] = None self._start_time: Optional[float] = None @@ -76,10 +88,6 @@ def __init__(self, scheduler: Scheduler, real: Realization) -> None: def iens(self) -> int: return self.real.iens - @property - def driver(self) -> Driver: - return self._scheduler.driver - @property def running_duration(self) -> float: if self._start_time: @@ -88,15 +96,20 @@ def running_duration(self) -> float: return time.time() - self._start_time return 0 - async def _submit_and_run_once(self, sem: asyncio.BoundedSemaphore) -> None: + async def _submit_and_run_once( + self, + sem: asyncio.BoundedSemaphore, + driver: Driver, + submit_sleep: Optional[SubmitSleeper] = None, + ) -> None: await sem.acquire() timeout_task: Optional[asyncio.Task[None]] = None try: - if self._scheduler.submit_sleep_state: - await self._scheduler.submit_sleep_state.sleep_until_we_can_submit() + if submit_sleep: + await submit_sleep.sleep_until_we_can_submit() await self._send(State.SUBMITTING) - await self.driver.submit( + await driver.submit( self.real.iens, self.real.job_script, self.real.run_arg.runpath, @@ -109,7 +122,7 @@ async def _submit_and_run_once(self, sem: asyncio.BoundedSemaphore) -> None: self._start_time = time.time() await self._send(State.RUNNING) - if self.real.max_runtime is not None and self.real.max_runtime > 0: + if (self.real.max_runtime or 0) > 0: timeout_task = asyncio.create_task(self._max_runtime_task()) returncode = await self.returncode @@ -124,12 +137,16 @@ async def _submit_and_run_once(self, sem: asyncio.BoundedSemaphore) -> None: if callback_status == LoadStatus.LOAD_SUCCESSFUL: await self._send(State.COMPLETED) + self._end_time = time.time() + if self.on_complete is not None: + await self.on_complete(self.iens) else: assert callback_status in ( LoadStatus.LOAD_FAILURE, LoadStatus.TIME_MAP_FAILURE, ) await self._send(State.FAILED) + await self._handle_failure() else: await self._send(State.FAILED) @@ -138,23 +155,29 @@ async def _submit_and_run_once(self, sem: asyncio.BoundedSemaphore) -> None: except asyncio.CancelledError: await self._send(State.ABORTING) - await self.driver.kill(self.iens) + await driver.kill(self.iens) with suppress(asyncio.CancelledError): await self.returncode await self._send(State.ABORTED) + await self._handle_aborted() finally: if timeout_task and not timeout_task.done(): timeout_task.cancel() sem.release() async def __call__( - self, start: asyncio.Event, sem: asyncio.BoundedSemaphore, max_submit: int = 2 + self, + sem: asyncio.BoundedSemaphore, + event_sender: EventSender, + driver: Driver, + max_submit: int = 2, + submit_sleep: Optional[SubmitSleeper] = None, ) -> None: + self._event_sender = event_sender self._requested_max_submit = max_submit - await start.wait() for attempt in range(max_submit): - await self._submit_and_run_once(sem) + await self._submit_and_run_once(sem, driver, submit_sleep) if self.returncode.cancelled() or ( self.returncode.done() and self.returncode.result() == 0 @@ -167,17 +190,16 @@ async def __call__( async def _max_runtime_task(self) -> None: assert self.real.max_runtime is not None await asyncio.sleep(self.real.max_runtime) - timeout_event = CloudEvent( - { - "type": EVTYPE_REALIZATION_TIMEOUT, - "source": f"/ert/ensemble/{self._scheduler._ens_id}/real/{self.iens}", - "id": str(uuid.uuid1()), - } - ) - assert self._scheduler._events is not None - await self._scheduler._events.put(to_json(timeout_event)) - self.returncode.cancel() # Triggers CancelledError + if self._event_sender is not None: + await self._event_sender.send( + "com.equinor.ert.realization.timeout", + f"real/{self.iens}", + attributes={"id": str(uuid.uuid1())}, + ) + + self._event_sender = None + self.returncode.cancel() async def _handle_failure(self) -> None: assert self._requested_max_submit is not None @@ -204,28 +226,23 @@ async def _handle_aborted(self) -> None: async def _send(self, state: State) -> None: self.state = state - if state == State.FAILED: - await self._handle_failure() - - elif state == State.ABORTED: - await self._handle_aborted() + if self._event_sender is None: + return - elif state == State.COMPLETED: - self._end_time = time.time() - await self._scheduler.completed_jobs.put(self.iens) + if (status := STATE_TO_EE.get(state)) is None: + # This message does not need to be propagated to the user + return - status = STATE_TO_LEGACY[state] - event = CloudEvent( - { - "type": _queue_state_event_type(status), - "source": f"/ert/ensemble/{self._scheduler._ens_id}/real/{self.iens}", + await self._event_sender.send( + status, + f"real/{self.iens}", + attributes={ "datacontenttype": "application/json", }, - { - "queue_event_type": status, + data={ + "queue_event_type": STATE_TO_LEGACY[state], }, ) - await self._scheduler._events.put(to_json(event)) def log_info_from_exit_file(exit_file_path: Path) -> None: diff --git a/src/ert/scheduler/scheduler.py b/src/ert/scheduler/scheduler.py index 19e98ccac37..e5c9468a0a6 100644 --- a/src/ert/scheduler/scheduler.py +++ b/src/ert/scheduler/scheduler.py @@ -4,22 +4,28 @@ import json import logging import os -import ssl import time from collections import defaultdict from dataclasses import asdict from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, MutableMapping, Optional, Sequence +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Mapping, + MutableMapping, + Optional, + Sequence, +) from pydantic.dataclasses import dataclass -from websockets import Headers -from websockets.client import connect from ert.async_utils import background_tasks from ert.constant_filenames import CERT_FILE from ert.job_queue.queue import EVTYPE_ENSEMBLE_CANCELLED, EVTYPE_ENSEMBLE_STOPPED from ert.scheduler.driver import Driver from ert.scheduler.event import FinishedEvent +from ert.scheduler.event_sender import EventSender from ert.scheduler.job import Job from ert.scheduler.job import State as JobState @@ -78,8 +84,9 @@ def __init__( if submit_sleep > 0: self.submit_sleep_state = SubmitSleeper(submit_sleep) - self._jobs: MutableMapping[int, Job] = { - real.iens: Job(self, real) for real in (realizations or []) + self._jobs: Mapping[int, Job] = { + real.iens: Job(real, on_complete=self._on_job_complete) + for real in (realizations or []) } self._events: asyncio.Queue[Any] = asyncio.Queue() @@ -87,16 +94,18 @@ def __init__( self._average_job_runtime: float = 0 self._completed_jobs_num: int = 0 - self.completed_jobs: asyncio.Queue[int] = asyncio.Queue() + self._completed_jobs: asyncio.Queue[int] = asyncio.Queue() self._cancelled = False self._max_submit = max_submit self._max_running = max_running - self._ee_uri = ee_uri - self._ens_id = ens_id - self._ee_cert = ee_cert - self._ee_token = ee_token + self.event_sender = EventSender( + ens_id=ens_id, + ee_uri=ee_uri, + ee_cert=ee_cert, + ee_token=ee_token, + ) def kill_all_jobs(self) -> None: assert self._loop @@ -110,15 +119,6 @@ async def cancel_all_jobs(self) -> None: for task in self._tasks.values(): task.cancel() - async def _update_avg_job_runtime(self) -> None: - while True: - iens = await self.completed_jobs.get() - self._average_job_runtime = ( - self._average_job_runtime * self._completed_jobs_num - + self._jobs[iens].running_duration - ) / (self._completed_jobs_num + 1) - self._completed_jobs_num += 1 - async def _stop_long_running_jobs( self, minimum_required_realizations: int, long_running_factor: float = 1.25 ) -> None: @@ -134,9 +134,6 @@ async def _stop_long_running_jobs( await task await asyncio.sleep(0.1) - def set_realization(self, realization: Realization) -> None: - self._jobs[realization.iens] = Job(self, realization) - def is_active(self) -> bool: return any(not task.done() for task in self._tasks.values()) @@ -146,30 +143,6 @@ def count_states(self) -> Dict[JobState, int]: counts[job.state] += 1 return counts - async def _publisher(self) -> None: - if not self._ee_uri: - return - tls: Optional[ssl.SSLContext] = None - if self._ee_cert: - tls = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) - tls.load_verify_locations(cadata=self._ee_cert) - headers = Headers() - if self._ee_token: - headers["token"] = self._ee_token - - async with connect( - self._ee_uri, - ssl=tls, - extra_headers=headers, - open_timeout=60, - ping_timeout=60, - ping_interval=60, - close_timeout=60, - ) as conn: - while True: - event = await self._events.get() - await conn.send(event) - def add_dispatch_information_to_jobs_file(self) -> None: for job in self._jobs.values(): self._update_jobs_json(job.iens, job.real.run_arg.runpath) @@ -182,29 +155,32 @@ async def execute( # cancel jobs from another thread self._loop = asyncio.get_running_loop() async with background_tasks() as cancel_when_execute_is_done: - cancel_when_execute_is_done(self._publisher()) + cancel_when_execute_is_done(self.event_sender.publisher()) cancel_when_execute_is_done(self._process_event_queue()) cancel_when_execute_is_done(self.driver.poll()) if min_required_realizations > 0: cancel_when_execute_is_done( self._stop_long_running_jobs(min_required_realizations) ) - cancel_when_execute_is_done(self._update_avg_job_runtime()) - start = asyncio.Event() sem = asyncio.BoundedSemaphore(self._max_running or len(self._jobs)) for iens, job in self._jobs.items(): self._tasks[iens] = asyncio.create_task( - job(start, sem, self._max_submit) + job( + sem, + self.event_sender, + self.driver, + self._max_submit, + self.submit_sleep_state, + ) ) - start.set() results = await asyncio.gather( *self._tasks.values(), return_exceptions=True ) for result in results: if isinstance(result, Exception): - logger.error(result) + logger.error(result, exc_info=result) await self.driver.finish() @@ -230,15 +206,15 @@ async def _process_event_queue(self) -> None: def _update_jobs_json(self, iens: int, runpath: str) -> None: cert_path = f"{runpath}/{CERT_FILE}" - if self._ee_cert is not None: - Path(cert_path).write_text(self._ee_cert, encoding="utf-8") + if self.event_sender.ee_cert is not None: + Path(cert_path).write_text(self.event_sender.ee_cert, encoding="utf-8") jobs = _JobsJson( experiment_id=None, - ens_id=self._ens_id, + ens_id=self.event_sender.ens_id, real_id=iens, - dispatch_url=self._ee_uri, - ee_token=self._ee_token, - ee_cert_path=cert_path if self._ee_cert is not None else None, + dispatch_url=self.event_sender.ee_uri, + ee_token=self.event_sender.ee_token, + ee_cert_path=self.event_sender.ee_cert and cert_path, ) jobs_path = os.path.join(runpath, "jobs.json") with open(jobs_path, "r") as fp: @@ -246,3 +222,10 @@ def _update_jobs_json(self, iens: int, runpath: str) -> None: with open(jobs_path, "w") as fp: data.update(asdict(jobs)) json.dump(data, fp, indent=4) + + async def _on_job_complete(self, iens: int) -> None: + self._average_job_runtime = ( + self._average_job_runtime * self._completed_jobs_num + + self._jobs[iens].running_duration + ) / (self._completed_jobs_num + 1) + self._completed_jobs_num += 1 diff --git a/src/ert/shared/_doc_utils/ert_jobs.py b/src/ert/shared/_doc_utils/ert_jobs.py index 9f23ac5f156..15bc18b18e7 100644 --- a/src/ert/shared/_doc_utils/ert_jobs.py +++ b/src/ert/shared/_doc_utils/ert_jobs.py @@ -123,9 +123,9 @@ class _ErtDocumentation(SphinxDirective): def _divide_into_categories( jobs: Dict[str, JobDoc], ) -> Dict[str, Dict[str, List[_ForwardModelDocumentation]]]: - categories: Dict[str, Dict[str, List[_ForwardModelDocumentation]]] = ( - defaultdict(lambda: defaultdict(list)) - ) + categories: Dict[ + str, Dict[str, List[_ForwardModelDocumentation]] + ] = defaultdict(lambda: defaultdict(list)) for job_name, docs in jobs.items(): # Job names in ERT traditionally used upper case letters # for the names of the job. However, at some point duplicate diff --git a/src/ert/simulator/simulation_context.py b/src/ert/simulator/simulation_context.py index 2d6d91d5000..49bcfe759d3 100644 --- a/src/ert/simulator/simulation_context.py +++ b/src/ert/simulator/simulation_context.py @@ -9,7 +9,6 @@ import numpy as np from ert.config import HookRuntime -from ert.config.parsing.queue_system import QueueSystem from ert.enkf_main import create_run_path from ert.ensemble_evaluator import Realization from ert.job_queue import JobQueue, JobStatus @@ -61,17 +60,6 @@ async def _submit_and_run_jobqueue( max_runtime, ert.ert_config.preferred_num_cpu, ) - else: - realization = Realization( - iens=run_arg.iens, - forward_models=[], - active=True, - max_runtime=max_runtime, - run_arg=run_arg, - num_cpu=ert.ert_config.preferred_num_cpu, - job_script=ert.ert_config.queue_config.job_script, - ) - job_queue.set_realization(realization) required_realizations = 0 if ert.ert_config.analysis_config.stop_long_running: @@ -94,19 +82,6 @@ def __init__( self._ert = ert self._mask = mask - if ( - ert.ert_config.queue_config.queue_system in [QueueSystem.LOCAL] - and FeatureToggling.value("scheduler") is not False - ): - FeatureToggling._conf["scheduler"].value = True - if ert.ert_config.queue_config.queue_system != QueueSystem.LOCAL: - raise NotImplementedError() - driver = create_driver(ert.ert_config.queue_config) - self._job_queue = Scheduler( - driver, max_running=ert.ert_config.queue_config.max_running - ) - else: - self._job_queue = JobQueue(ert.ert_config.queue_config) # fill in the missing geo_id data global_substitutions = ert.ert_config.substitution_list global_substitutions[""] = _slug(sim_fs.name) @@ -125,6 +100,33 @@ def __init__( iteration=itr, ) + if FeatureToggling.is_enabled("scheduler"): + driver = create_driver(ert.ert_config.queue_config) + + max_runtime: Optional[int] = ert.ert_config.analysis_config.max_runtime + if max_runtime == 0: + max_runtime = None + + self._job_queue = Scheduler( + driver, + max_running=ert.ert_config.queue_config.max_running, + realizations=[ + Realization( + iens=run_arg.iens, + forward_models=[], + active=True, + max_runtime=max_runtime, + run_arg=run_arg, + num_cpu=ert.ert_config.preferred_num_cpu, + job_script=ert.ert_config.queue_config.job_script, + ) + for run_arg in self._run_context + ], + ) + + else: + self._job_queue = JobQueue(ert.ert_config.queue_config) + create_run_path(self._run_context, global_substitutions, self._ert.ert_config) self._ert.runWorkflows( HookRuntime.PRE_SIMULATION, None, self._run_context.sim_fs diff --git a/tests/unit_tests/scheduler/test_job.py b/tests/unit_tests/scheduler/test_job.py index bccd61f6be1..c16f3601a7e 100644 --- a/tests/unit_tests/scheduler/test_job.py +++ b/tests/unit_tests/scheduler/test_job.py @@ -1,24 +1,26 @@ import asyncio -import json import shutil from typing import List -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import MagicMock import pytest -import ert +import ert.scheduler.job from ert.ensemble_evaluator._builder._realization import Realization from ert.load_status import LoadResult, LoadStatus from ert.run_arg import RunArg -from ert.scheduler import Scheduler +from ert.scheduler.event_sender import EventSender from ert.scheduler.job import STATE_TO_LEGACY, Job, State -def create_scheduler(): - sch = AsyncMock() - sch._events = asyncio.Queue() - sch.driver = AsyncMock() - return sch +@pytest.fixture +def driver(mock_driver): + return mock_driver() + + +@pytest.fixture +def event_sender(): + return EventSender(None, None, None, None) @pytest.fixture @@ -43,36 +45,31 @@ def realization(): return realization -async def assert_scheduler_events( - scheduler: Scheduler, job_events: List[State] -) -> None: +async def assert_events(event_sender: EventSender, job_events: List[State]) -> None: for job_event in job_events: - queue_event = await scheduler._events.get() - output = json.loads(queue_event.decode("utf-8")) - event = output.get("data").get("queue_event_type") + queue_event = await event_sender.events.get() + event = (queue_event.get_data() or {}).get("queue_event_type") assert event == STATE_TO_LEGACY[job_event] # should be no more events - assert scheduler._events.empty() + assert event_sender.events.empty() @pytest.mark.timeout(5) -async def test_submitted_job_is_cancelled(realization, mock_event): - scheduler = create_scheduler() - job = Job(scheduler, realization) - job._requested_max_submit = 1 +async def test_submitted_job_is_cancelled( + realization, mock_event, event_sender, driver +): + job = Job(realization) job.started = mock_event() job.returncode.cancel() - job_task = asyncio.create_task(job._submit_and_run_once(asyncio.BoundedSemaphore())) + job_task = asyncio.create_task( + job(asyncio.BoundedSemaphore(), event_sender, driver, max_submit=1) + ) await asyncio.wait_for(job.started._mock_waited, 5) assert job_task.cancel() await job_task - await assert_scheduler_events( - scheduler, [State.SUBMITTING, State.PENDING, State.ABORTING, State.ABORTED] - ) - scheduler.driver.kill.assert_called_with(job.iens) - scheduler.driver.kill.assert_called_once() + await assert_events(event_sender, [State.SUBMITTING, State.PENDING, State.ABORTED]) @pytest.mark.parametrize( @@ -90,6 +87,8 @@ async def test_job_submit_and_run_once( forward_model_ok_result, expected_final_event: State, realization: Realization, + event_sender, + driver, monkeypatch, ): monkeypatch.setattr( @@ -97,23 +96,13 @@ async def test_job_submit_and_run_once( "forward_model_ok", lambda _: LoadResult(forward_model_ok_result, ""), ) - scheduler = create_scheduler() - job = Job(scheduler, realization) - job._requested_max_submit = 1 + job = Job(realization) job.started.set() job.returncode.set_result(return_code) - await job._submit_and_run_once(asyncio.Semaphore()) + await job(asyncio.BoundedSemaphore(), event_sender, driver, max_submit=1) - await assert_scheduler_events( - scheduler, + await assert_events( + event_sender, [State.SUBMITTING, State.PENDING, State.RUNNING, expected_final_event], ) - scheduler.driver.submit.assert_called_with( - realization.iens, - realization.job_script, - realization.run_arg.runpath, - name=realization.run_arg.job_name, - runpath=realization.run_arg.runpath, - ) - scheduler.driver.submit.assert_called_once() diff --git a/tests/unit_tests/scheduler/test_scheduler.py b/tests/unit_tests/scheduler/test_scheduler.py index 03b6c4c3520..a3f4f6a4ee5 100644 --- a/tests/unit_tests/scheduler/test_scheduler.py +++ b/tests/unit_tests/scheduler/test_scheduler.py @@ -7,7 +7,6 @@ from typing import List import pytest -from cloudevents.http import from_json from flaky import flaky from ert.constant_filenames import CERT_FILE @@ -199,26 +198,28 @@ async def init(*args, **kwargs): @pytest.mark.timeout(10) async def test_max_runtime(realization, mock_driver): - wait_started = asyncio.Event() + has_timed_out = False + + class MockEventSender: + async def publisher(self): + pass + + async def send(self, state, *args, **kwargs): + if state == "com.equinor.ert.realization.timeout": + nonlocal has_timed_out + has_timed_out = True async def wait(): - wait_started.set() await asyncio.sleep(100) realization.max_runtime = 1 sch = scheduler.Scheduler(mock_driver(wait=wait), [realization]) + sch.event_sender = MockEventSender() - result = await asyncio.create_task(sch.execute()) - assert wait_started.is_set() + result = await sch.execute() assert result == EVTYPE_ENSEMBLE_STOPPED - - timeouteventfound = False - while not timeouteventfound and not sch._events.empty(): - event = await sch._events.get() - if from_json(event)["type"] == "com.equinor.ert.realization.timeout": - timeouteventfound = True - assert timeouteventfound + assert has_timed_out @pytest.mark.parametrize("max_running", [0, 1, 2, 10]) @@ -263,13 +264,20 @@ async def wait(): @pytest.mark.timeout(6) async def test_max_runtime_while_killing(realization, mock_driver): - wait_started = asyncio.Event() now_kill_me = asyncio.Event() + has_timed_out = False + + class MockEventSender: + async def publisher(self): + pass + + async def send(self, state, *args, **kwargs): + if state == "com.equinor.ert.realization.timeout": + nonlocal has_timed_out + has_timed_out = True async def wait(): # A realization function that lives forever if it was not killed - wait_started.set() - await asyncio.sleep(0.1) now_kill_me.set() await asyncio.sleep(1000) @@ -281,26 +289,21 @@ async def kill(): realization.max_runtime = 1 sch = scheduler.Scheduler(mock_driver(wait=wait, kill=kill), [realization]) + sch.event_sender = MockEventSender() scheduler_task = asyncio.create_task(sch.execute()) await now_kill_me.wait() + await asyncio.sleep(0) await sch.cancel_all_jobs() # this is equivalent to sch.kill_all_jobs() # Sleep until max_runtime must have kicked in: await asyncio.sleep(1.1) - timeouteventfound = False - while not timeouteventfound and not sch._events.empty(): - event = await sch._events.get() - if from_json(event)["type"] == "com.equinor.ert.realization.timeout": - timeouteventfound = True - # Assert that a timeout_event is actually emitted, because killing took a # long time, and that we should exit normally (asserting no bad things # happen just because we have two things killing the realization). - - assert timeouteventfound + assert has_timed_out await scheduler_task # The result from execute is that we were cancelled, not stopped