From c1d2e2120102c3cda1abbb6c3e9d98bc5a8eb63a Mon Sep 17 00:00:00 2001
From: DanSava
Date: Fri, 17 May 2024 12:34:57 +0300
Subject: [PATCH] Wait for NFS file synchronization

Ensure that the files generated by a job match their expected checksums.
The scheduler waits for the job's runpath to appear in its checksum
dictionary, verifies the checksums of the generated files, and logs any
errors or discrepancies.
---
 .../reporting/event.py                        |  15 +++
 .../reporting/message.py                      |  18 ++-
 .../reporting/statemachine.py                 |   7 +-
 src/_ert_forward_model_runner/runner.py       |  31 ++++-
 src/ert/config/ert_config.py                  |  27 ++++
 src/ert/enkf_main.py                          |   4 +
 src/ert/ensemble_evaluator/evaluator.py       |  45 +++++--
 src/ert/ensemble_evaluator/identifiers.py     |   1 +
 src/ert/event_type_constants.py               |   2 +
 src/ert/scheduler/job.py                      |  47 +++++++
 src/ert/scheduler/scheduler.py                |  65 +++++++++
 .../status/test_tracking_integration.py       |  11 ++
 .../unit_tests/ensemble_evaluator/conftest.py |  18 +++
 .../test_ensemble_legacy.py                   |  22 +---
 .../ensemble_evaluator/test_scheduler.py      |  67 ++++++++++
 .../forward_model_runner/test_jobmanager.py   |  58 +++++++-
 tests/unit_tests/scheduler/test_job.py        | 124 ++++++++++++++++++
 17 files changed, 528 insertions(+), 34 deletions(-)
 create mode 100644 tests/unit_tests/ensemble_evaluator/test_scheduler.py

diff --git a/src/_ert_forward_model_runner/reporting/event.py b/src/_ert_forward_model_runner/reporting/event.py
index 3543a34ff73..30159b45650 100644
--- a/src/_ert_forward_model_runner/reporting/event.py
+++ b/src/_ert_forward_model_runner/reporting/event.py
@@ -17,6 +17,7 @@ from _ert_forward_model_runner.reporting.base import Reporter
 from _ert_forward_model_runner.reporting.message import (
     _JOB_EXIT_FAILED_STRING,
+    Checksum,
     Exited,
     Finish,
     Init,
@@ -29,11 +30,13 @@
 _FORWARD_MODEL_START = "com.equinor.ert.forward_model_job.start"
 _FORWARD_MODEL_RUNNING = "com.equinor.ert.forward_model_job.running"
 _FORWARD_MODEL_SUCCESS = "com.equinor.ert.forward_model_job.success"
+_FORWARD_MODEL_CHECKSUM = "com.equinor.ert.forward_model_job.checksum"
 _FORWARD_MODEL_FAILURE = "com.equinor.ert.forward_model_job.failure"

 _CONTENT_TYPE = "datacontenttype"
 _JOB_MSG_TYPE = "type"
 _JOB_SOURCE = "source"
+_RUN_PATH = "run_path"

 logger = logging.getLogger(__name__)

@@ -66,6 +69,7 @@ def __init__(self, evaluator_url, token=None, cert_path=None):
         self._statemachine = StateMachine()
         self._statemachine.add_handler((Init,), self._init_handler)
         self._statemachine.add_handler((Start, Running, Exited), self._job_handler)
+        self._statemachine.add_handler((Checksum,), self._checksum_handler)
         self._statemachine.add_handler((Finish,), self._finished_handler)

         self._ens_id = None
@@ -193,3 +197,14 @@ def _finished_handler(self, msg):
         )
         if self._event_publisher_thread.is_alive():
             self._event_publisher_thread.join()
+
+    def _checksum_handler(self, msg):
+        job_msg_attrs = {
+            _JOB_SOURCE: (f"/ert/ensemble/{self._ens_id}/real/{self._real_id}"),
+            _CONTENT_TYPE: "application/json",
+            _RUN_PATH: msg.run_path,
+        }
+        self._dump_event(
+            attributes={_JOB_MSG_TYPE: _FORWARD_MODEL_CHECKSUM, **job_msg_attrs},
+            data=msg.data,
+        )
diff --git a/src/_ert_forward_model_runner/reporting/message.py b/src/_ert_forward_model_runner/reporting/message.py
index 5d04155e12f..052b18cf350 100644
--- a/src/_ert_forward_model_runner/reporting/message.py
+++ b/src/_ert_forward_model_runner/reporting/message.py
@@ -1,12 +1,21 @@
 import dataclasses
 from datetime import datetime as dt
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Dict, Literal, Optional, TypedDict

 import psutil

 if TYPE_CHECKING:
     from _ert_forward_model_runner.job import Job

+    class _ChecksumDictBase(TypedDict):
+        type: Literal["file"]
+        path: str
+
+    class ChecksumDict(_ChecksumDictBase, total=False):
+        md5sum: str
+        error: str
+
+
 _JOB_STATUS_SUCCESS = "Success"
 _JOB_STATUS_RUNNING = "Running"
 _JOB_STATUS_FAILURE = "Failure"
@@ -119,3 +128,10 @@ class Exited(Message):
     def __init__(self, job, exit_code):
         super().__init__(job)
         self.exit_code = exit_code
+
+
+class Checksum(Message):
+    def __init__(self, checksum_dict: Dict[str, "ChecksumDict"], run_path: str):
+        super().__init__()
+        self.data = checksum_dict
+        self.run_path = run_path
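For orientation, a Checksum message wraps a plain dict keyed by manifest entry name, where each value is a ChecksumDict. A minimal sketch of constructing one; the file names and the md5 hex digest below are hypothetical, only the dict layout is taken from the patch:

    from _ert_forward_model_runner.reporting.message import Checksum

    # Hypothetical payload: one file that was hashed, one that was never created.
    checksum_data = {
        "RESULT_FILE": {
            "type": "file",
            "path": "/scratch/real_0/result.txt",
            "md5sum": "9a0364b9e99bb480dd25e1f0284c8555",
        },
        "MISSING_FILE": {
            "type": "file",
            "path": "/scratch/real_0/missing.txt",
            "error": "Expected file /scratch/real_0/missing.txt not created by forward model!",
        },
    }
    msg = Checksum(checksum_dict=checksum_data, run_path="/scratch/real_0")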
diff --git a/src/_ert_forward_model_runner/reporting/statemachine.py b/src/_ert_forward_model_runner/reporting/statemachine.py
index 680d9b1e0ab..6009a4c2ebf 100644
--- a/src/_ert_forward_model_runner/reporting/statemachine.py
+++ b/src/_ert_forward_model_runner/reporting/statemachine.py
@@ -2,6 +2,7 @@
 from typing import Callable, Dict, Tuple, Type

 from _ert_forward_model_runner.reporting.message import (
+    Checksum,
     Exited,
     Finish,
     Init,
@@ -22,12 +23,14 @@ def __init__(self) -> None:
         logger.debug("Initializing state machines")
         initialized = (Init,)
         jobs = (Start, Running, Exited)
+        checksum = (Checksum,)
         finished = (Finish,)
         self._handler: Dict[Message, Callable[[Message], None]] = {}
         self._transitions = {
             None: initialized,
-            initialized: jobs + finished,
-            jobs: jobs + finished,
+            initialized: jobs + checksum + finished,
+            jobs: jobs + checksum + finished,
+            checksum: checksum + finished,
         }
         self._state = None
diff --git a/src/_ert_forward_model_runner/runner.py b/src/_ert_forward_model_runner/runner.py
index d6f19018203..5b62ca3e5c6 100644
--- a/src/_ert_forward_model_runner/runner.py
+++ b/src/_ert_forward_model_runner/runner.py
@@ -1,11 +1,15 @@
+import hashlib
+import json
 import os
+from pathlib import Path

 from _ert_forward_model_runner.job import Job
-from _ert_forward_model_runner.reporting.message import Finish, Init
+from _ert_forward_model_runner.reporting.message import Checksum, Finish, Init


 class ForwardModelRunner:
     def __init__(self, jobs_data):
+        self.jobs_data = jobs_data
         self.simulation_id = jobs_data.get("run_id")
         self.experiment_id = jobs_data.get("experiment_id")
         self.ens_id = jobs_data.get("ens_id")
@@ -23,6 +27,27 @@ def __init__(self, jobs_data):

         self._set_environment()

+    def _read_manifest(self):
+        if not Path("manifest.json").exists():
+            return None
+        with open("manifest.json", mode="r", encoding="utf-8") as f:
+            data = json.load(f)
+        return {
+            name: {"type": "file", "path": str(Path(file).absolute())}
+            for name, file in data.items()
+        }
+
+    def _populate_checksums(self, manifest):
+        if not manifest:
+            return None
+        for info in manifest.values():
+            path = Path(info["path"])
+            if path.exists():
+                info["md5sum"] = hashlib.md5(path.read_bytes()).hexdigest()
+            else:
+                info["error"] = f"Expected file {path} not created by forward model!"
+        return manifest
+
     def run(self, names_of_jobs_to_run):
         # if names_of_jobs_to_run, create job_queue which contains jobs that
         # are to be run.
@@ -30,7 +55,6 @@ def run(self, names_of_jobs_to_run):
             job_queue = self.jobs
         else:
             job_queue = [j for j in self.jobs if j.name() in names_of_jobs_to_run]
-
         init_message = Init(
             job_queue,
             self.simulation_id,
@@ -56,9 +80,12 @@ def run(self, names_of_jobs_to_run):
             yield status_update

             if not status_update.success():
+                yield Checksum(checksum_dict=None, run_path=os.getcwd())
                 yield Finish().with_error("Not all jobs completed successfully.")
                 return

+        checksum_dict = self._populate_checksums(self._read_manifest())
+        yield Checksum(checksum_dict=checksum_dict, run_path=os.getcwd())
        yield Finish()

     def _set_environment(self):
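The two new helpers form a small pipeline: _read_manifest resolves each manifest entry to an absolute path, and _populate_checksums attaches either an md5sum or an error per entry. A self-contained sketch of the same flow for experimenting outside the runner; the function name is invented, only the manifest.json layout is taken from the patch:

    import hashlib
    import json
    from pathlib import Path

    def checksums_for_manifest(runpath: Path) -> dict:
        # Mirrors _read_manifest + _populate_checksums from the diff above.
        manifest_file = runpath / "manifest.json"
        if not manifest_file.exists():
            return {}
        entries = json.loads(manifest_file.read_text(encoding="utf-8"))
        checksums = {}
        for name, relative_path in entries.items():
            path = (runpath / relative_path).absolute()
            info = {"type": "file", "path": str(path)}
            if path.exists():
                # md5 here detects truncated or partially synced files, not tampering.
                info["md5sum"] = hashlib.md5(path.read_bytes()).hexdigest()
            else:
                info["error"] = f"Expected file {path} not created by forward model!"
            checksums[name] = info
        return checksums

Note that the runner emits a Checksum message even when the run fails (with checksum_dict=None), so the consuming side is never left waiting for a message that will not come.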
diff --git a/src/ert/config/ert_config.py b/src/ert/config/ert_config.py
index 81803123647..1aaf09e87bc 100644
--- a/src/ert/config/ert_config.py
+++ b/src/ert/config/ert_config.py
@@ -528,6 +528,33 @@ def _create_list_of_forward_model_steps_to_run(
     def forward_model_step_name_list(self) -> List[str]:
         return [j.name for j in self.forward_model_steps]

+    def manifest_to_json(self, iens: int = 0, iter: int = 0) -> Dict[str, Any]:
+        manifest = {}
+        # Add expected parameter files to manifest
+        if iter == 0:
+            for (
+                name,
+                parameter_config,
+            ) in self.ensemble_config.parameter_configs.items():
+                if parameter_config.forward_init and parameter_config.forward_init_file:
+                    file_path = parameter_config.forward_init_file.replace(
+                        "%d", str(iens)
+                    )
+                    manifest[name] = file_path
+        # Add expected response files to manifest
+        for name, respons_config in self.ensemble_config.response_configs.items():
+            input_file = str(respons_config.input_file)
+            if isinstance(respons_config, SummaryConfig):
+                input_file = input_file.replace("<IENS>", str(iens))
+                manifest[f"{name}_UNSMRY"] = f"{input_file}.UNSMRY"
+                manifest[f"{name}_SMSPEC"] = f"{input_file}.SMSPEC"
+            if isinstance(respons_config, GenDataConfig):
+                if respons_config.report_steps and iens in respons_config.report_steps:
+                    manifest[name] = input_file.replace("%d", str(iens))
+                elif "%d" not in input_file:
+                    manifest[name] = input_file
+        return manifest
+
     def forward_model_data_to_json(
         self,
         run_id: Optional[str] = None,
diff --git a/src/ert/enkf_main.py b/src/ert/enkf_main.py
index dfe93abc604..4003e37c606 100644
--- a/src/ert/enkf_main.py
+++ b/src/ert/enkf_main.py
@@ -234,6 +234,10 @@ def create_run_path(
             )
             json.dump(forward_model_output, fptr)

+        # Write the manifest file to the runpath; it is used to avoid NFS sync issues
+        with open(run_path / "manifest.json", mode="w", encoding="utf-8") as fptr:
+            data = ert_config.manifest_to_json(run_arg.iens, run_arg.itr)
+            json.dump(data, fptr)

     run_context.runpaths.write_runpath_list(
         [run_context.iteration], run_context.active_realizations
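To make the producer side concrete: for a hypothetical config with a summary response (eclbase ECLIPSE_CASE) and one gen-data response, the manifest written to each runpath would look roughly like this. The entry names and file names are illustrative, following the key patterns in manifest_to_json above:

    import json

    # Plausible contents of <runpath>/manifest.json for realization 0, iteration 0
    manifest = {
        "summary_UNSMRY": "ECLIPSE_CASE.UNSMRY",
        "summary_SMSPEC": "ECLIPSE_CASE.SMSPEC",
        "GEN_DATA_RESPONSE": "gen_data_output_0",
    }
    with open("manifest.json", mode="w", encoding="utf-8") as fptr:
        json.dump(manifest, fptr)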
diff --git a/src/ert/ensemble_evaluator/evaluator.py b/src/ert/ensemble_evaluator/evaluator.py
index b321221a46e..f3d2e18ec44 100644
--- a/src/ert/ensemble_evaluator/evaluator.py
+++ b/src/ert/ensemble_evaluator/evaluator.py
@@ -43,6 +43,7 @@
     EVTYPE_ENSEMBLE_FAILED,
     EVTYPE_ENSEMBLE_STARTED,
     EVTYPE_ENSEMBLE_STOPPED,
+    EVTYPE_FORWARD_MODEL_CHECKSUM,
 )
 from .snapshot import PartialSnapshot
 from .state import (
@@ -100,6 +101,18 @@ def config(self) -> EvaluatorServerConfig:
     def ensemble(self) -> Ensemble:
         return self._ensemble

+    async def forward_checksum(self, event: CloudEvent) -> None:
+        forward_event = CloudEvent(
+            {
+                "type": EVTYPE_FORWARD_MODEL_CHECKSUM,
+                "source": f"/ert/ensemble/{self.ensemble.id_}",
+            },
+            {event["run_path"]: event.data},
+        )
+        await self._send_message(
+            to_json(forward_event, data_marshaller=evaluator_marshaller).decode()
+        )
+
     def _fm_handler(self, events: List[CloudEvent]) -> None:
         with self._snapshot_mutex:
             snapshot_update_event = self.ensemble.update_snapshot(events)
@@ -171,18 +184,7 @@ async def _send_snapshot_update(
             EVTYPE_EE_SNAPSHOT_UPDATE,
             snapshot_update_event.to_dict(),
         )
-        if message and self._clients:
-            # Note return_exceptions=True in gather. This fire-and-forget
-            # approach is currently how we deal with failures when trying
-            # to send udates to clients. Rationale is that if sending to
-            # the client fails, the websocket is down and we have no way
-            # to re-establish it. Thus, it becomes the responsibility of
-            # the client to re-connect if necessary, in which case the first
-            # update it receives will be a full snapshot.
-            await asyncio.gather(
-                *[client.send(message) for client in self._clients],
-                return_exceptions=True,
-            )
+        await self._send_message(message)

     def _create_cloud_event(
         self,
@@ -282,7 +284,10 @@ async def handle_dispatch(
                     )
                     continue
                 try:
-                    await self._dispatcher.handle_event(event)
+                    if event["type"] == EVTYPE_FORWARD_MODEL_CHECKSUM:
+                        await self.forward_checksum(event)
+                    else:
+                        await self._dispatcher.handle_event(event)
                 except BaseException as ex:
                     # Exceptions include asyncio.InvalidStateError, and
                     # anything that self._*_handler() can raise (updates
@@ -430,3 +435,17 @@ def get_successful_realizations(self) -> List[int]:
     def _get_ens_id(source: str) -> str:
         # the ens_id will be found at /ert/ensemble/ens_id/...
         return source.split("/")[3]
+
+    async def _send_message(self, message: Optional[str] = None) -> None:
+        if message and self._clients:
+            # Note return_exceptions=True in gather. This fire-and-forget
+            # approach is currently how we deal with failures when trying
+            # to send updates to clients. Rationale is that if sending to
+            # the client fails, the websocket is down and we have no way
+            # to re-establish it. Thus, it becomes the responsibility of
+            # the client to re-connect if necessary, in which case the first
+            # update it receives will be a full snapshot.
+            await asyncio.gather(
+                *[client.send(message) for client in self._clients],
+                return_exceptions=True,
+            )
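The evaluator does not inspect checksum payloads; forward_checksum only re-keys the data by runpath before broadcasting to clients, so a consumer can merge many such events into one dict and look up its own runpath later. A sketch of the transformation, with illustrative paths and values:

    # Payload of a dispatched checksum event from one realization:
    incoming_data = {"RESULT_FILE": {"type": "file", "path": "/scratch/real_0/result.txt"}}
    run_path = "/scratch/real_0"

    # forward_checksum wraps it so that scheduler.checksum.update(event.data)
    # accumulates one entry per runpath:
    outgoing_data = {run_path: incoming_data}
    assert outgoing_data["/scratch/real_0"]["RESULT_FILE"]["path"].endswith("result.txt")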
diff --git a/src/ert/ensemble_evaluator/identifiers.py b/src/ert/ensemble_evaluator/identifiers.py
index 8a8ce3abe8c..54fc81dcc1d 100644
--- a/src/ert/ensemble_evaluator/identifiers.py
+++ b/src/ert/ensemble_evaluator/identifiers.py
@@ -35,6 +35,7 @@
 EVTYPE_FORWARD_MODEL_RUNNING = "com.equinor.ert.forward_model_job.running"
 EVTYPE_FORWARD_MODEL_SUCCESS = "com.equinor.ert.forward_model_job.success"
 EVTYPE_FORWARD_MODEL_FAILURE = "com.equinor.ert.forward_model_job.failure"
+EVTYPE_FORWARD_MODEL_CHECKSUM = "com.equinor.ert.forward_model_job.checksum"


 EVGROUP_REALIZATION = {
diff --git a/src/ert/event_type_constants.py b/src/ert/event_type_constants.py
index ec956c7eb3e..02d76c98e2a 100644
--- a/src/ert/event_type_constants.py
+++ b/src/ert/event_type_constants.py
@@ -10,3 +10,5 @@
 EVTYPE_ENSEMBLE_STOPPED = "com.equinor.ert.ensemble.stopped"
 EVTYPE_ENSEMBLE_CANCELLED = "com.equinor.ert.ensemble.cancelled"
 EVTYPE_ENSEMBLE_FAILED = "com.equinor.ert.ensemble.failed"
+
+EVTYPE_FORWARD_MODEL_CHECKSUM = "com.equinor.ert.forward_model_job.checksum"
diff --git a/src/ert/scheduler/job.py b/src/ert/scheduler/job.py
index a12fb01e80a..af867952036 100644
--- a/src/ert/scheduler/job.py
+++ b/src/ert/scheduler/job.py
@@ -1,6 +1,7 @@
 from __future__ import annotations

 import asyncio
+import hashlib
 import logging
 import time
 import uuid
@@ -136,6 +137,8 @@ async def run(self, sem: asyncio.BoundedSemaphore, max_submit: int = 2) -> None:
                 break

             if self.returncode.result() == 0:
+                if self._scheduler.wait_for_checksum():
+                    await self._verify_checksum()
                 await self._handle_finished_forward_model()
                 break

@@ -167,6 +170,50 @@ async def _max_runtime_task(self) -> None:
         )
         self.returncode.cancel()

+    async def _verify_checksum(self, timeout: int = 120) -> None:
+        # Wait for the job's runpath to appear in the checksum dictionary
+        runpath = self.real.run_arg.runpath
+        while runpath not in self._scheduler.checksum:
+            if timeout <= 0:
+                break
+            timeout -= 1
+            await asyncio.sleep(1)
+
+        checksum = self._scheduler.checksum.get(runpath)
+        if checksum is None:
+            logger.warning(f"Checksum information not received for {runpath}")
+            return
+
+        errors = "\n".join(
+            [info["error"] for info in checksum.values() if "error" in info]
+        )
+        if errors:
+            logger.error(errors)
+
+        valid_checksums = [info for info in checksum.values() if "error" not in info]
+
+        # Wait for the files in the checksum dictionary to appear on disk
+        while not all(Path(info["path"]).exists() for info in valid_checksums):
+            if timeout <= 0:
+                break
+            timeout -= 1
+            logger.debug("Waiting for disk synchronization")
+            await asyncio.sleep(1)
+
+        for info in valid_checksums:
+            file_path = Path(info["path"])
+            expected_md5sum = info.get("md5sum")
+            if file_path.exists() and expected_md5sum:
+                actual_md5sum = hashlib.md5(file_path.read_bytes()).hexdigest()
+                if expected_md5sum == actual_md5sum:
+                    logger.debug(f"File {file_path} checksum successful.")
+                else:
+                    logger.warning(f"File {file_path} checksum verification failed.")
+            elif file_path.exists() and expected_md5sum is None:
+                logger.warning(f"Checksum not received for file {file_path}")
+            else:
+                logger.error(f"Disk synchronization failed for {file_path}")
+
     async def _handle_finished_forward_model(self) -> None:
         callback_status, status_msg = forward_model_ok(self.real.run_arg)
         if self._callback_status_msg:
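Note that _verify_checksum spends a single timeout budget across both wait loops: seconds used while waiting for the checksum message are no longer available for the disk-sync wait. The per-file decision at the end can also be read in isolation; a condensed, synchronous sketch (standalone; the function name is invented):

    import hashlib
    import logging
    from pathlib import Path

    logger = logging.getLogger(__name__)

    def verify_file(info: dict) -> bool:
        # Condensed version of the per-file branch in Job._verify_checksum.
        file_path = Path(info["path"])
        expected = info.get("md5sum")
        if file_path.exists() and expected:
            actual = hashlib.md5(file_path.read_bytes()).hexdigest()
            if expected == actual:
                logger.debug(f"File {file_path} checksum successful.")
                return True
            logger.warning(f"File {file_path} checksum verification failed.")
        elif file_path.exists():
            logger.warning(f"Checksum not received for file {file_path}")
        else:
            logger.error(f"Disk synchronization failed for {file_path}")
        return False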
diff --git a/src/ert/scheduler/scheduler.py b/src/ert/scheduler/scheduler.py
index da61798838c..c67c93ea6da 100644
--- a/src/ert/scheduler/scheduler.py
+++ b/src/ert/scheduler/scheduler.py
@@ -21,12 +21,16 @@
     Sequence,
 )

+from aiohttp import ClientError
+from cloudevents.exceptions import DataUnmarshallerError
+from cloudevents.http import from_json
 from pydantic.dataclasses import dataclass
 from websockets import ConnectionClosed, Headers
 from websockets.client import connect

 from _ert.async_utils import get_running_loop
 from ert.constant_filenames import CERT_FILE
+from ert.event_type_constants import EVTYPE_FORWARD_MODEL_CHECKSUM
 from ert.job_queue.queue import (
     CLOSE_PUBLISHER_SENTINEL,
     EVTYPE_ENSEMBLE_CANCELLED,
@@ -36,6 +40,7 @@
 from ert.scheduler.event import FinishedEvent
 from ert.scheduler.job import Job
 from ert.scheduler.job import State as JobState
+from ert.serialization import evaluator_unmarshaller

 if TYPE_CHECKING:
     from ert.ensemble_evaluator._builder._realization import Realization
@@ -116,6 +121,22 @@ def __init__(
         self._ee_cert = ee_cert
         self._ee_token = ee_token
         self._publisher_done = asyncio.Event()
+        self._consumer_started = asyncio.Event()
+        self.checksum: Dict[str, Dict[str, Any]] = {}
+        self.checksum_listener: Optional[asyncio.Task[None]] = None
+
+    async def start_manifest_listener(self) -> Optional[asyncio.Task[None]]:
+        if self._ee_uri is None or "dispatch" not in self._ee_uri:
+            return None
+
+        self.checksum_listener = asyncio.create_task(
+            self._checksum_consumer(), name="consumer_task"
+        )
+        await self._consumer_started.wait()
+        return self.checksum_listener
+
+    def wait_for_checksum(self) -> bool:
+        return self._consumer_started.is_set()

     def kill_all_jobs(self) -> None:
         assert self._loop
@@ -184,6 +205,47 @@ def count_states(self) -> Dict[JobState, int]:
             counts[job.state] += 1
         return counts

+    async def _checksum_consumer(self) -> None:
+        if not self._ee_uri:
+            return
+        tls: Optional[ssl.SSLContext] = None
+        if self._ee_cert:
+            tls = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
+            tls.load_verify_locations(cadata=self._ee_cert)
+        headers = Headers()
+        if self._ee_token:
+            headers["token"] = self._ee_token
+        event = None
+        async for conn in connect(
+            self._ee_uri.replace("dispatch", "client"),
+            ssl=tls,
+            extra_headers=headers,
+            max_size=2**26,
+            max_queue=500,
+            open_timeout=5,
+            ping_timeout=60,
+            ping_interval=60,
+            close_timeout=60,
+        ):
+            try:
+                self._consumer_started.set()
+                async for message in conn:
+                    try:
+                        event = from_json(
+                            str(message), data_unmarshaller=evaluator_unmarshaller
+                        )
+                        if event["type"] == EVTYPE_FORWARD_MODEL_CHECKSUM:
+                            self.checksum.update(event.data)
+                    except DataUnmarshallerError:
+                        logger.error(
+                            "Scheduler checksum consumer received an unknown message"
+                        )
+            except (ConnectionRefusedError, ConnectionClosed, ClientError) as exc:
+                self._consumer_started.clear()
+                logger.debug(
+                    f"Scheduler connection to EnsembleEvaluator went down: {exc}"
+                )
+
     async def _publisher(self) -> None:
         if not self._ee_uri:
             return
@@ -268,6 +330,7 @@ async def execute(
         self,
         min_required_realizations: int = 0,
     ) -> str:
+        listener_task = await self.start_manifest_listener()
         scheduling_tasks = [
             asyncio.create_task(self._publisher(), name="publisher_task"),
             asyncio.create_task(
@@ -275,6 +338,8 @@ async def execute(
             ),
             asyncio.create_task(self.driver.poll(), name="poll_task"),
         ]
+        if listener_task is not None:
+            scheduling_tasks.append(listener_task)

         if min_required_realizations > 0:
             scheduling_tasks.append(
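The consumer leans on the websockets library's async-iterator form of connect, which transparently reconnects between iterations; that is why _consumer_started is set anew for every yielded connection and cleared when the connection drops. A stripped-down sketch of the same pattern, with a placeholder URI:

    from websockets import ConnectionClosed
    from websockets.client import connect

    async def consume(uri: str = "ws://127.0.0.1:51820/client") -> None:
        # Each iteration of `async for conn in connect(...)` is a fresh,
        # automatically re-established connection.
        async for conn in connect(uri, open_timeout=5):
            try:
                async for message in conn:
                    print("got message:", message)
            except ConnectionClosed:
                continue  # connect() retries with backoff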
diff --git a/tests/integration_tests/status/test_tracking_integration.py b/tests/integration_tests/status/test_tracking_integration.py
index a20bfa19468..b3c7e75ca4a 100644
--- a/tests/integration_tests/status/test_tracking_integration.py
+++ b/tests/integration_tests/status/test_tracking_integration.py
@@ -32,6 +32,7 @@
     ENSEMBLE_SMOOTHER_MODE,
     TEST_RUN_MODE,
 )
+from ert.shared.feature_toggling import FeatureScheduler


 class Events:
@@ -412,3 +413,13 @@ def test_tracking_missing_ecl(tmpdir, caplog, storage):
         f"{Path().absolute()}/simulations/realization-0/"
         "iter-0/ECLIPSE_CASE"
     ) in failures[0].failed_msg
+    if FeatureScheduler._value:
+        case = f"{Path().absolute()}/simulations/realization-0/iter-0/ECLIPSE_CASE"
+        assert (
+            f"Expected file {case}.UNSMRY not created by forward model!\nExpected "
+            f"file {case}.SMSPEC not created by forward model!"
+        ) in caplog.messages
+        assert (
+            f"Expected file {case}.UNSMRY not created by forward model!\nExpected "
+            f"file {case}.SMSPEC not created by forward model!"
+        ) in failures[0].failed_msg
diff --git a/tests/unit_tests/ensemble_evaluator/conftest.py b/tests/unit_tests/ensemble_evaluator/conftest.py
index d6514b55acf..b6132d9b6db 100644
--- a/tests/unit_tests/ensemble_evaluator/conftest.py
+++ b/tests/unit_tests/ensemble_evaluator/conftest.py
@@ -1,12 +1,15 @@
+import asyncio
 import json
 import os
 import stat
 from pathlib import Path
+from typing import Any, Callable, Coroutine
 from unittest.mock import MagicMock, Mock

 import pytest

 import ert.ensemble_evaluator
+from _ert.async_utils import new_event_loop
 from ert.config import QueueConfig, QueueSystem
 from ert.config.ert_config import _forward_model_step_from_config_file
 from ert.ensemble_evaluator.config import EvaluatorServerConfig
@@ -189,3 +192,18 @@ def evaluator(make_ee_config):
     )
     yield ee
     ee.stop()
+
+
+@pytest.fixture(name="run_monitor_in_loop")
+def monitor_loop():
+    def run_monitor_in_loop(
+        monitor_func: Callable[[], Coroutine[Any, Any, None]],
+    ) -> bool:
+        loop = new_event_loop()
+        asyncio.set_event_loop(loop)
+        try:
+            return loop.run_until_complete(monitor_func())
+        finally:
+            loop.close()
+
+    return run_monitor_in_loop
diff --git a/tests/unit_tests/ensemble_evaluator/test_ensemble_legacy.py b/tests/unit_tests/ensemble_evaluator/test_ensemble_legacy.py
index 0c10d9e7c9f..2a2feea5b5c 100644
--- a/tests/unit_tests/ensemble_evaluator/test_ensemble_legacy.py
+++ b/tests/unit_tests/ensemble_evaluator/test_ensemble_legacy.py
@@ -1,13 +1,10 @@
-import asyncio
 import contextlib
 import os
-from typing import Any, Callable, Coroutine
 from unittest.mock import MagicMock, patch

 import pytest
 from websockets.exceptions import ConnectionClosed

-from _ert.async_utils import new_event_loop
 from ert.config import QueueConfig
 from ert.ensemble_evaluator import Monitor, identifiers, state
 from ert.ensemble_evaluator.config import EvaluatorServerConfig
@@ -17,18 +14,11 @@
 from ert.shared.feature_toggling import FeatureScheduler


-def run_monitor_in_loop(monitor_func: Callable[[], Coroutine[Any, Any, None]]) -> bool:
-    loop = new_event_loop()
-    asyncio.set_event_loop(loop)
-    try:
-        return loop.run_until_complete(monitor_func())
-    finally:
-        loop.close()
-
-
 @pytest.mark.timeout(60)
 @pytest.mark.usefixtures("using_scheduler")
-def test_run_legacy_ensemble(tmpdir, make_ensemble_builder, monkeypatch):
+def test_run_legacy_ensemble(
+    tmpdir, make_ensemble_builder, monkeypatch, run_monitor_in_loop
+):
     num_reals = 2
     custom_port_range = range(1024, 65535)
     with tmpdir.as_cwd():
@@ -66,7 +56,9 @@ async def _run_monitor():

 @pytest.mark.timeout(60)
 @pytest.mark.usefixtures("using_scheduler")
-def test_run_and_cancel_legacy_ensemble(tmpdir, make_ensemble_builder, monkeypatch):
+def test_run_and_cancel_legacy_ensemble(
+    tmpdir, make_ensemble_builder, monkeypatch, run_monitor_in_loop
+):
     num_reals = 2
     custom_port_range = range(1024, 65535)
     with tmpdir.as_cwd():
@@ -115,7 +107,7 @@ async def _run_monitor():

 @pytest.mark.timeout(10)
 def test_run_legacy_ensemble_with_bare_exception(
-    tmpdir, make_ensemble_builder, monkeypatch
+    tmpdir, make_ensemble_builder, monkeypatch, run_monitor_in_loop
 ):
     """This test function is not ported to Scheduler, as it will not
     catch general exceptions."""
diff --git a/tests/unit_tests/ensemble_evaluator/test_scheduler.py b/tests/unit_tests/ensemble_evaluator/test_scheduler.py
new file mode 100644
index 00000000000..16b70d73575
--- /dev/null
+++ b/tests/unit_tests/ensemble_evaluator/test_scheduler.py
@@ -0,0 +1,67 @@
+import asyncio
+import json
+import logging
+from pathlib import Path
+
+import pytest
+
+from ert.ensemble_evaluator import Monitor, identifiers, state
+from ert.ensemble_evaluator.config import EvaluatorServerConfig
+from ert.ensemble_evaluator.evaluator import EnsembleEvaluator
+
+
+@pytest.mark.timeout(60)
+def test_scheduler_receives_checksum_and_waits_for_disk_sync(
+    tmpdir, make_ensemble_builder, monkeypatch, caplog, run_monitor_in_loop
+):
+    num_reals = 1
+    custom_port_range = range(1024, 65535)
+
+    async def rename_and_wait():
+        Path("real_0/job_test_file").rename("real_0/test")
+        while "Waiting for disk synchronization" not in caplog.messages:
+            await asyncio.sleep(0.1)
+        Path("real_0/test").rename("real_0/job_test_file")
+
+    async def _run_monitor():
+        async with Monitor(config) as monitor:
+            async for e in monitor.track():
+                if e["type"] == identifiers.EVTYPE_FORWARD_MODEL_CHECKSUM:
+                    # The monitor got the checksum message; rename the file
+                    # before the scheduler gets the same message
+                    try:
+                        await asyncio.wait_for(rename_and_wait(), timeout=5)
+                    except asyncio.TimeoutError:
+                        await monitor.signal_done()
+                if e["type"] in (
+                    identifiers.EVTYPE_EE_SNAPSHOT_UPDATE,
+                    identifiers.EVTYPE_EE_SNAPSHOT,
+                ) and e.data.get(identifiers.STATUS) in [
+                    state.ENSEMBLE_STATE_FAILED,
+                    state.ENSEMBLE_STATE_STOPPED,
+                ]:
+                    await monitor.signal_done()
+        return True
+
+    with tmpdir.as_cwd():
+        ensemble = make_ensemble_builder(monkeypatch, tmpdir, num_reals, 2).build()
+
+        # Create a manifest file for the test
+        with open("real_0/manifest.json", mode="w", encoding="utf-8") as f:
+            json.dump({"file": "job_test_file"}, f)
+        file_path = Path("real_0/job_test_file")
+        file_path.write_text("test")
+        config = EvaluatorServerConfig(
+            custom_port_range=custom_port_range,
+            custom_host="127.0.0.1",
+            use_token=False,
+            generate_cert=False,
+        )
+        evaluator = EnsembleEvaluator(ensemble, config, 0)
+        with caplog.at_level(logging.DEBUG):
+            evaluator.start_running()
+            run_monitor_in_loop(_run_monitor)
+            assert "Waiting for disk synchronization" in caplog.messages
+            assert f"File {file_path.absolute()} checksum successful." in caplog.messages
+        assert evaluator._ensemble.status == state.ENSEMBLE_STATE_STOPPED
diff --git a/tests/unit_tests/forward_model_runner/test_jobmanager.py b/tests/unit_tests/forward_model_runner/test_jobmanager.py
index d1d437838d0..482acbc8770 100644
--- a/tests/unit_tests/forward_model_runner/test_jobmanager.py
+++ b/tests/unit_tests/forward_model_runner/test_jobmanager.py
@@ -5,7 +5,7 @@

 import pytest

-from _ert_forward_model_runner.reporting.message import Exited, Start
+from _ert_forward_model_runner.reporting.message import Checksum, Exited, Start
 from _ert_forward_model_runner.runner import ForwardModelRunner
 from ert.config import ErtConfig, ForwardModelStep
 from ert.config.ert_config import _forward_model_step_from_config_file
@@ -123,6 +123,62 @@ def test_run_multiple_ok():
         assert os.path.getsize(f"mkdir_err.{dir_number}") == 0


+@pytest.mark.usefixtures("use_tmpdir")
+def test_when_forward_model_contains_multiple_jobs_just_one_checksum_status_is_given():
+    joblist = []
+    file_list = ["1", "2", "3", "4", "5"]
+    manifest = {}
+    for job_index in file_list:
+        manifest[f"file_{job_index}"] = job_index
+        job = {
+            "name": "TOUCH",
+            "executable": "touch",
+            "stdout": f"touch_out.{job_index}",
+            "stderr": f"touch_err.{job_index}",
+            "argList": [job_index],
+        }
+        joblist.append(job)
+    with open("manifest.json", "w", encoding="utf-8") as f:
+        json.dump(manifest, f)
+
+    jobm = ForwardModelRunner(create_jobs_json(joblist))
+
+    statuses = [s for s in list(jobm.run([])) if isinstance(s, Checksum)]
+    assert len(statuses) == 1
+    assert len(statuses[0].data) == 5
+
+
+@pytest.mark.usefixtures("use_tmpdir")
+def test_when_manifest_file_is_not_created_by_fm_runner_checksum_contains_error():
+    joblist = []
+    file_name = "test"
+    manifest = {"file_1": f"{file_name}"}
+
+    joblist.append(
+        {
+            "name": "TOUCH",
+            "executable": "touch",
+            "stdout": "touch_out.test",
+            "stderr": "touch_err.test",
+            "argList": ["not_test"],
+        }
+    )
+    with open("manifest.json", "w", encoding="utf-8") as f:
+        json.dump(manifest, f)
+
+    jobm = ForwardModelRunner(create_jobs_json(joblist))
+
+    checksum_msg = [s for s in list(jobm.run([])) if isinstance(s, Checksum)]
+    assert len(checksum_msg) == 1
+    info = checksum_msg[0].data["file_1"]
+    assert "md5sum" not in info
+    assert "error" in info
+    assert (
+        f"Expected file {os.getcwd()}/{file_name} not created by forward model!"
+        in info["error"]
+    )
+
+
 @pytest.mark.usefixtures("use_tmpdir")
 def test_run_multiple_fail_only_runs_one():
     joblist = []
diff --git a/tests/unit_tests/scheduler/test_job.py b/tests/unit_tests/scheduler/test_job.py
index 0ddcd6f88af..f6492790d3d 100644
--- a/tests/unit_tests/scheduler/test_job.py
+++ b/tests/unit_tests/scheduler/test_job.py
@@ -1,6 +1,8 @@
 import asyncio
 import json
+import logging
 import shutil
+from functools import partial
 from pathlib import Path
 from typing import List
 from unittest.mock import AsyncMock, MagicMock
@@ -11,6 +13,7 @@
 from ert.ensemble_evaluator._builder._realization import Realization
 from ert.load_status import LoadResult, LoadStatus
 from ert.run_arg import RunArg
+from ert.run_models.base_run_model import captured_logs
 from ert.scheduler import Scheduler
 from ert.scheduler.job import STATE_TO_LEGACY, Job, State

@@ -19,6 +22,7 @@ def create_scheduler():
     sch = AsyncMock()
     sch._events = asyncio.Queue()
     sch.driver = AsyncMock()
+    sch.wait_for_checksum = lambda: False
     sch._cancelled = False
     return sch

@@ -121,6 +125,7 @@ async def test_job_run_sends_expected_events(
         forward_model_ok_result, ""
     )
     job = Job(scheduler, realization)
+    job._verify_checksum = partial(job._verify_checksum, timeout=0)
     job.started.set()

     job_run_task = asyncio.create_task(
@@ -155,3 +160,122 @@ async def test_job_run_sends_expected_events(
         num_cpu=realization.num_cpu,
     )
     assert scheduler.driver.submit.call_count == max_submit
+
+
+@pytest.mark.asyncio
+async def test_when_waiting_for_disk_sync_times_out_an_error_is_logged(
+    realization: Realization, monkeypatch
+):
+    scheduler = create_scheduler()
+    scheduler.wait_for_checksum = lambda: True
+    file_path = "does/not/exist"
+    scheduler.checksum = {
+        "test_runpath": {
+            "file": {
+                "path": file_path,
+                "md5sum": "something",
+            }
+        }
+    }
+    log_msgs = []
+    job = Job(scheduler, realization)
+    job._verify_checksum = partial(job._verify_checksum, timeout=0)
+    job.started.set()
+
+    with captured_logs(log_msgs, logging.ERROR):
+        job_run_task = asyncio.create_task(job.run(asyncio.Semaphore(), max_submit=1))
+        job.started.set()
+        job.returncode.set_result(0)
+        await job_run_task
+
+    assert "Disk synchronization failed for does/not/exist" in log_msgs
+
+
+@pytest.mark.asyncio
+async def test_when_files_in_manifest_are_not_created_an_error_is_logged(
+    realization: Realization, monkeypatch
+):
+    scheduler = create_scheduler()
+    scheduler.wait_for_checksum = lambda: True
+    file_path = "does/not/exist"
+    error = f"Expected file {file_path} not created by forward model!"
+    scheduler.checksum = {
+        "test_runpath": {
+            "file": {
+                "path": file_path,
+                "error": error,
+            }
+        }
+    }
+    log_msgs = []
+    job = Job(scheduler, realization)
+    job.started.set()
+
+    with captured_logs(log_msgs, logging.ERROR):
+        job_run_task = asyncio.create_task(job.run(asyncio.Semaphore(), max_submit=1))
+        job.started.set()
+        job.returncode.set_result(0)
+        await job_run_task
+
+    assert error in log_msgs
+
+
+@pytest.mark.usefixtures("use_tmpdir")
+@pytest.mark.asyncio
+async def test_when_checksums_do_not_match_a_warning_is_logged(
+    realization: Realization,
+):
+    scheduler = create_scheduler()
+    scheduler.wait_for_checksum = lambda: True
+    file_path = "invalid_md5sum"
+    scheduler.checksum = {
+        "test_runpath": {
+            "file": {
+                "path": file_path,
+                "md5sum": "something_something_checksum",
+            }
+        }
+    }
+    # Create the file the checksum entry points at
+    Path(file_path).write_text("test")
+
+    log_msgs = []
+    job = Job(scheduler, realization)
+    job.started.set()
+
+    with captured_logs(log_msgs, logging.WARNING):
+        job_run_task = asyncio.create_task(job.run(asyncio.Semaphore(), max_submit=1))
+        job.started.set()
+        job.returncode.set_result(0)
+        await job_run_task
+
+    assert f"File {file_path} checksum verification failed." in log_msgs
+
+
+@pytest.mark.usefixtures("use_tmpdir")
+@pytest.mark.asyncio
+async def test_when_no_checksum_info_is_received_a_warning_is_logged(
+    realization: Realization, mocker
+):
+    scheduler = create_scheduler()
+    scheduler.wait_for_checksum = lambda: True
+    scheduler.checksum = {}
+
+    log_msgs = []
+    job = Job(scheduler, realization)
+    job.started.set()
+
+    # Mock asyncio.sleep to fast-forward time
+    mocker.patch("asyncio.sleep", return_value=None)
+
+    with captured_logs(log_msgs, logging.WARNING):
+        job_run_task = asyncio.create_task(job.run(asyncio.Semaphore(), max_submit=1))
+        job.started.set()
+        job.returncode.set_result(0)
+        await job_run_task
+
+    assert (
+        f"Checksum information not received for {realization.run_arg.runpath}"
+        in log_msgs
+    )