diff --git a/src/_ert/forward_model_runner/client.py b/src/_ert/forward_model_runner/client.py index 62b18c590ad..c767483d72d 100644 --- a/src/_ert/forward_model_runner/client.py +++ b/src/_ert/forward_model_runner/client.py @@ -23,9 +23,8 @@ class ClientConnectionClosedOK(Exception): class Client: - DEFAULT_MAX_RETRIES = 10 - DEFAULT_TIMEOUT_MULTIPLIER = 5 - CONNECTION_TIMEOUT = 60 + DEFAULT_MAX_RETRIES = 5 + DEFAULT_ACK_TIMEOUT = 5 _receiver_task: Optional[asyncio.Task[None]] def __enter__(self) -> Self: @@ -63,10 +62,10 @@ def __init__( url: str, token: Optional[str] = None, cert: Optional[Union[str, bytes]] = None, - connection_timeout: float = 5.0, dealer_name: Optional[str] = None, + ack_timeout: Optional[float] = None, ) -> None: - self._connection_timeout = connection_timeout + self._ack_timeout = ack_timeout or self.DEFAULT_ACK_TIMEOUT self.url = url self.token = token @@ -80,7 +79,7 @@ def __init__( else: self.dealer_id = dealer_name self.socket.setsockopt_string(zmq.IDENTITY, self.dealer_id) - print(f"Created: {self.dealer_id=} {token=} {self._connection_timeout=}") + print(f"Created: {self.dealer_id=} {token=} {self._ack_timeout=}") if token is not None: client_public, client_secret = zmq.curve_keypair() self.socket.curve_secretkey = client_secret @@ -95,14 +94,14 @@ async def connect(self) -> None: await self._term_receiver_task() self._receiver_task = asyncio.create_task(self._receiver()) try: - await self._send("CONNECT", max_retries=1) + await self._send("CONNECT", retries=1) except ClientConnectionError: await self._term_receiver_task() self.term() raise - def send(self, message: str, max_retries: int = DEFAULT_MAX_RETRIES) -> None: - self.loop.run_until_complete(self._send(message, max_retries)) + def send(self, message: str, retries: Optional[int] = None) -> None: + self.loop.run_until_complete(self._send(message, retries)) async def process_message(self, msg: str) -> None: pass @@ -122,17 +121,17 @@ async def _receiver(self) -> None: await 
asyncio.sleep(1) self.socket.connect(self.url) - async def _send(self, message: str, max_retries: int = DEFAULT_MAX_RETRIES) -> None: + async def _send(self, message: str, retries: Optional[int] = None) -> None: self._ack_event.clear() backoff = 1 - - while max_retries > 0: + retries = retries or self.DEFAULT_MAX_RETRIES + while retries > 0: try: await self.socket.send_multipart([b"", message.encode("utf-8")]) try: await asyncio.wait_for( - self._ack_event.wait(), timeout=self._connection_timeout + self._ack_event.wait(), timeout=self._ack_timeout ) return except asyncio.TimeoutError: @@ -149,12 +148,11 @@ async def _send(self, message: str, max_retries: int = DEFAULT_MAX_RETRIES) -> N self.term() raise - max_retries -= 1 - if max_retries > 0: - logger.info(f"Retrying... ({max_retries} attempts left)") + retries -= 1 + if retries > 0: + logger.info(f"Retrying... ({retries} attempts left)") await asyncio.sleep(backoff) backoff = min(backoff * 2, 10) # Exponential backoff - raise ClientConnectionError( - f"{self.dealer_id} Failed to send {message=} after retries." 
+ f"{self.dealer_id} Failed to send {message=} after {retries=}" ) diff --git a/src/_ert/forward_model_runner/reporting/event.py b/src/_ert/forward_model_runner/reporting/event.py index 15052f27249..5feac6fc268 100644 --- a/src/_ert/forward_model_runner/reporting/event.py +++ b/src/_ert/forward_model_runner/reporting/event.py @@ -3,6 +3,7 @@ import asyncio import logging import queue +import time from pathlib import Path from typing import Final, Union @@ -27,7 +28,7 @@ Start, ) from _ert.forward_model_runner.reporting.statemachine import StateMachine -from _ert.threading import ErtThread +from _ert.threading import ErtThread, threading logger = logging.getLogger(__name__) @@ -54,7 +55,14 @@ class Event(Reporter): _sentinel: Final = EventSentinel() - def __init__(self, evaluator_url, token=None, cert_path=None): + def __init__( + self, + evaluator_url, + token=None, + cert_path=None, + ack_timeout=None, + max_retries=None, + ): self._evaluator_url = evaluator_url self._token = token if cert_path is not None: @@ -73,10 +81,13 @@ def __init__(self, evaluator_url, token=None, cert_path=None): self._real_id = None self._event_queue: queue.Queue[events.Event | EventSentinel] = queue.Queue() self._event_publisher_thread = ErtThread(target=self._event_publisher) - self._done = False + self._done = threading.Event() + self._ack_timeout = ack_timeout + self._max_retries = max_retries def stop(self): self._event_queue.put(Event._sentinel) + self._done.set() if self._event_publisher_thread.is_alive(): self._event_publisher_thread.join() @@ -86,19 +97,31 @@ async def publisher(): url=self._evaluator_url, token=self._token, cert=self._cert, + ack_timeout=self._ack_timeout, ) as client: + event = None + start_time = None while True: try: - event = self._event_queue.get() - if event is self._sentinel: + if self._done.is_set() and start_time is None: + start_time = time.time() + if event is None: + event = self._event_queue.get() + if event is self._sentinel: + break + if 
start_time and (time.time() - start_time) > 2: break - await client._send(event_to_json(event)) + await client._send(event_to_json(event), self._max_retries) + event = None except asyncio.CancelledError: return except ClientConnectionError as exc: logger.error(f"Failed to send event: {exc}") - asyncio.run(publisher()) + try: + asyncio.run(publisher()) + except ClientConnectionError as exc: + raise ClientConnectionError("Couldn't connect to evaluator") from exc def report(self, msg): self._statemachine.transition(msg) @@ -162,6 +185,7 @@ def _job_handler(self, msg: Union[Start, Running, Exited]): def _finished_handler(self, _): self._event_queue.put(Event._sentinel) + self._done.set() if self._event_publisher_thread.is_alive(): self._event_publisher_thread.join() diff --git a/src/ert/ensemble_evaluator/_ensemble.py b/src/ert/ensemble_evaluator/_ensemble.py index d706f99edeb..cc7c8de0307 100644 --- a/src/ert/ensemble_evaluator/_ensemble.py +++ b/src/ert/ensemble_evaluator/_ensemble.py @@ -204,7 +204,7 @@ async def send_event( retries: int = 10, ) -> None: async with Client(url, token, cert) as client: - await client._send(event_to_json(event), max_retries=retries) + await client._send(event_to_json(event), retries) def generate_event_creator(self) -> Callable[[Id.ENSEMBLE_TYPES], Event]: def event_builder(status: str) -> Event: diff --git a/src/ert/ensemble_evaluator/evaluator.py b/src/ert/ensemble_evaluator/evaluator.py index 07200b0a940..0d208ec30f6 100644 --- a/src/ert/ensemble_evaluator/evaluator.py +++ b/src/ert/ensemble_evaluator/evaluator.py @@ -262,7 +262,6 @@ async def listen_for_messages(self) -> None: if sender.startswith("client"): await self.handle_client(dealer, frame) elif sender.startswith("dispatch"): - # await self._router_socket.send_multipart([dealer, b"", b"ACK"]) await self.handle_dispatch(dealer, frame) else: logger.info(f"Connection attempt to unknown sender: {sender}.") diff --git a/tests/ert/ui_tests/cli/test_cli.py 
b/tests/ert/ui_tests/cli/test_cli.py index 5cc45df200e..dc402457793 100644 --- a/tests/ert/ui_tests/cli/test_cli.py +++ b/tests/ert/ui_tests/cli/test_cli.py @@ -952,9 +952,8 @@ def test_tracking_missing_ecl(monkeypatch, tmp_path, caplog): def test_that_connection_errors_do_not_effect_final_result( monkeypatch: pytest.MonkeyPatch, ): - monkeypatch.setattr(Client, "DEFAULT_MAX_RETRIES", 0) - monkeypatch.setattr(Client, "DEFAULT_TIMEOUT_MULTIPLIER", 0) - monkeypatch.setattr(Client, "CONNECTION_TIMEOUT", 1) + monkeypatch.setattr(Client, "DEFAULT_MAX_RETRIES", 1) + monkeypatch.setattr(Client, "DEFAULT_ACK_TIMEOUT", 1) monkeypatch.setattr(EnsembleEvaluator, "CLOSE_SERVER_TIMEOUT", 0.01) monkeypatch.setattr(Job, "DEFAULT_CHECKSUM_TIMEOUT", 0) diff --git a/tests/ert/unit_tests/ensemble_evaluator/test_ensemble_client.py b/tests/ert/unit_tests/ensemble_evaluator/test_ensemble_client.py index a545bca8bdf..86d9d8dc46c 100644 --- a/tests/ert/unit_tests/ensemble_evaluator/test_ensemble_client.py +++ b/tests/ert/unit_tests/ensemble_evaluator/test_ensemble_client.py @@ -14,7 +14,7 @@ def test_invalid_server(): with ( pytest.raises(ClientConnectionError), - Client(url, connection_timeout=1.0), + Client(url, ack_timeout=1.0), ): pass @@ -31,7 +31,8 @@ async def test_successful_sending(unused_tcp_port): await server_started.wait() messages_c1 = ["test_1", "test_2", "test_3"] async with Client(url) as c1: - await c1._send(messages_c1) + for message in messages_c1: + await c1._send(message) await server_task diff --git a/tests/ert/unit_tests/ensemble_evaluator/test_monitor.py b/tests/ert/unit_tests/ensemble_evaluator/test_monitor.py index 78c7a4748ec..fc60ada420c 100644 --- a/tests/ert/unit_tests/ensemble_evaluator/test_monitor.py +++ b/tests/ert/unit_tests/ensemble_evaluator/test_monitor.py @@ -24,7 +24,7 @@ async def async_zmq_server(port, handler): async def test_no_connection_established(make_ee_config): ee_config = make_ee_config() monitor = 
Monitor(ee_config.get_connection_info()) - monitor._connection_timeout = 0.1 + monitor._ack_timeout = 0.1 with pytest.raises(ClientConnectionError): async with monitor: pass diff --git a/tests/ert/unit_tests/forward_model_runner/test_event_reporter.py b/tests/ert/unit_tests/forward_model_runner/test_event_reporter.py index 0cdd5eec24c..8fe8a4c6e73 100644 --- a/tests/ert/unit_tests/forward_model_runner/test_event_reporter.py +++ b/tests/ert/unit_tests/forward_model_runner/test_event_reporter.py @@ -1,6 +1,6 @@ import os +import queue import time -from unittest.mock import patch import pytest @@ -11,7 +11,6 @@ ForwardModelStepSuccess, event_from_json, ) -from _ert.forward_model_runner.client import ClientConnectionError from _ert.forward_model_runner.forward_model_step import ForwardModelStep from _ert.forward_model_runner.reporting import Event from _ert.forward_model_runner.reporting.message import ( @@ -41,7 +40,7 @@ def test_report_with_successful_start_message_argument(unused_tcp_port): {"name": "fmstep1", "stdout": "stdout", "stderr": "stderr"}, 0 ) lines = [] - with mock_zmq_thread(host, unused_tcp_port, lines): + with mock_zmq_thread(unused_tcp_port, lines): reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0)) reporter.report(Start(fmstep1)) reporter.report(Finish()) @@ -66,7 +65,7 @@ def test_report_with_failed_start_message_argument(unused_tcp_port): ) lines = [] - with mock_zmq_thread(host, unused_tcp_port, lines): + with mock_zmq_thread(unused_tcp_port, lines): reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0)) msg = Start(fmstep1).with_error("massive_failure") @@ -89,7 +88,7 @@ def test_report_with_successful_exit_message_argument(unused_tcp_port): ) lines = [] - with mock_zmq_thread(host, unused_tcp_port, lines): + with mock_zmq_thread(unused_tcp_port, lines): reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0)) reporter.report(Exited(fmstep1, 0)) reporter.report(Finish().with_error("failed")) @@ -108,7 
+107,7 @@ def test_report_with_failed_exit_message_argument(unused_tcp_port): ) lines = [] - with mock_zmq_thread(host, unused_tcp_port, lines): + with mock_zmq_thread(unused_tcp_port, lines): reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0)) reporter.report(Exited(fmstep1, 1).with_error("massive_failure")) reporter.report(Finish()) @@ -128,7 +127,7 @@ def test_report_with_running_message_argument(unused_tcp_port): ) lines = [] - with mock_zmq_thread(host, unused_tcp_port, lines): + with mock_zmq_thread(unused_tcp_port, lines): reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0)) reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=100, rss=10))) reporter.report(Finish()) @@ -149,7 +148,7 @@ def test_report_only_job_running_for_successful_run(unused_tcp_port): ) lines = [] - with mock_zmq_thread(host, unused_tcp_port, lines): + with mock_zmq_thread(unused_tcp_port, lines): reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0)) reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=100, rss=10))) reporter.report(Finish()) @@ -166,7 +165,7 @@ def test_report_with_failed_finish_message_argument(unused_tcp_port): ) lines = [] - with mock_zmq_thread(host, unused_tcp_port, lines): + with mock_zmq_thread(unused_tcp_port, lines): reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0)) reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=100, rss=10))) reporter.report(Finish().with_error("massive_failure")) @@ -188,7 +187,6 @@ def test_report_inconsistent_events(unused_tcp_port): reporter.report(Finish()) -@pytest.mark.integration_test def test_report_with_failed_reporter_but_finished_jobs(unused_tcp_port): # this is to show when the reporter fails ert won't crash nor # staying hanging but instead finishes up the job; @@ -196,129 +194,108 @@ def test_report_with_failed_reporter_but_finished_jobs(unused_tcp_port): # also assert reporter._timeout_timestamp is None # meaning Finish event initiated 
_timeout and timeout was reached # which then sets _timeout_timestamp=None - mock_send_retry_time = 2 - - def mock_send(msg): - time.sleep(mock_send_retry_time) - raise ClientConnectionError("Sending failed!") host = "localhost" url = f"tcp://{host}:{unused_tcp_port}" - reporter = Event(evaluator_url=url) - reporter._reporter_timeout = 4 - fmstep1 = ForwardModelStep( - {"name": "fmstep1", "stdout": "stdout", "stderr": "stderr"}, 0 - ) lines = [] - with mock_zmq_thread(host, unused_tcp_port, lines): - with patch( - "_ert.forward_model_runner.client.Client.send", lambda x, y: mock_send(y) - ): - reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0)) - reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=100, rss=10))) - reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=1100, rss=10))) - reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=1100, rss=10))) - reporter.report(Finish()) + signal_queue = queue.Queue() + with mock_zmq_thread(unused_tcp_port, lines, signal_queue): + reporter = Event(evaluator_url=url, ack_timeout=1, max_retries=1) + fmstep1 = ForwardModelStep( + {"name": "fmstep1", "stdout": "stdout", "stderr": "stderr"}, 0 + ) + + signal_queue.put(1) # prevent router to receive messages + reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0)) + reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=100, rss=10))) + reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=1100, rss=10))) + reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=1100, rss=10))) + reporter.report(Finish()) if reporter._event_publisher_thread.is_alive(): reporter._event_publisher_thread.join() - assert reporter._done + assert reporter._done.is_set() assert len(lines) == 0, "expected 0 Job running messages" - # TODO refactor or remove, zmq handles reconnection automatically - # @pytest.mark.integration_test - # @pytest.mark.flaky(reruns=5) - # @pytest.mark.skipif( - # sys.platform.startswith("darwin"), reason="Performance 
can be flaky" - # ) - # def test_report_with_reconnected_reporter_but_finished_jobs(unused_tcp_port): - # # this is to show when the reporter fails but reconnects - # # reporter still manages to send events and completes fine - # # see assert reporter._timeout_timestamp is not None - # # meaning Finish event initiated _timeout but timeout wasn't reached since - # # it finished succesfully - # mock_send_retry_time = 0.1 - - # def send_func(msg): - # time.sleep(mock_send_retry_time) - # raise ClientConnectionError("Sending failed!") - - # host = "localhost" - # url = f"tcp://{host}:{unused_tcp_port}" - # reporter = Event(evaluator_url=url) - # fmstep1 = ForwardModelStep( - # {"name": "fmstep1", "stdout": "stdout", "stderr": "stderr"}, 0 - # ) - # lines = [] - # with mock_zmq_thread(host, unused_tcp_port, lines): - # with patch("_ert.forward_model_runner.client.Client.send") as patched_send: - # patched_send.side_effect = send_func - - # reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0)) - # reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=100, rss=10))) - # reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=200, rss=10))) - # reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=300, rss=10))) - - # _wait_until( - # condition=lambda: patched_send.call_count >= 1, - # timeout=10, - # fail_msg="10 seconds should be sufficient to send three events", - # ) - - # # reconnect and continue sending events - # reporter.report(Finish()) - # if reporter._event_publisher_thread.is_alive(): - # reporter._event_publisher_thread.join() - # assert reporter._done - # assert len(lines) == 3, "expected 3 Job running messages" - - # @pytest.mark.integration_test - # def test_report_with_closed_received_exiting_gracefully(unused_tcp_port): - # # Whenever the receiver end closes the connection, a ConnectionClosedOK is raised - # # The reporter should exit the publisher thread gracefully and not send any - # # more events - # mock_send_retry_time = 3 
- - # def mock_send(msg): - # time.sleep(mock_send_retry_time) - # raise ClientConnectionClosedOK("Connection Closed") - - # host = "localhost" - # url = f"tcp://{host}:{unused_tcp_port}" - # reporter = Event(evaluator_url=url) - # fmstep1 = ForwardModelStep( - # {"name": "fmstep1", "stdout": "stdout", "stderr": "stderr"}, 0 - # ) - # lines = [] - # with mock_zmq_thread(host, unused_tcp_port, lines): - # reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0)) - # reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=100, rss=10))) - # reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=200, rss=10))) - - # # sleep until both Running events have been received - # _wait_until( - # condition=lambda: len(lines) == 2, - # timeout=10, - # fail_msg="Should not take 10 seconds to send two events", - # ) - - # with patch( - # "_ert.forward_model_runner.client.Client.send", lambda x, y: mock_send(y) - # ): - # reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=300, rss=10))) - # # Make sure the publisher thread exits because it got - # # ClientConnectionClosedOK. If it hangs it could indicate that the - # # exception is not caught/handled correctly - # if reporter._event_publisher_thread.is_alive(): - # reporter._event_publisher_thread.join() - - # reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=400, rss=10))) - # reporter.report(Finish()) - - # # set _stop_timestamp was not set to None since the reporter finished on time - # assert reporter._timeout_timestamp is not None - - # # The Running(fmstep1, 300, 10) is popped from the queue, but never sent. 
- # # The following Running is added to queue along with the sentinel - # assert reporter._event_queue.qsize() == 2 - # # None of the messages after ClientConnectionClosedOK was raised, has been sent - # assert len(lines) == 2, "expected 2 Job running messages" + +def test_report_with_reconnected_reporter_but_finished_jobs(unused_tcp_port): + # this is to show when the reporter fails but reconnects + # reporter still manages to send events and completes fine + # see assert reporter._timeout_timestamp is not None + # meaning Finish event initiated _timeout but timeout wasn't reached since + # it finished successfully + + host = "localhost" + url = f"tcp://{host}:{unused_tcp_port}" + lines = [] + signal_queue = queue.Queue() + with mock_zmq_thread(unused_tcp_port, lines, signal_queue): + reporter = Event(evaluator_url=url, ack_timeout=1, max_retries=1) + fmstep1 = ForwardModelStep( + {"name": "fmstep1", "stdout": "stdout", "stderr": "stderr"}, 0 + ) + + signal_queue.put(1) # prevent the router from receiving messages + reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0)) + reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=100, rss=10))) + reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=1100, rss=10))) + reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=1100, rss=10))) + signal_queue.put(0) # enable the router to receive messages + reporter.report(Finish()) + if reporter._event_publisher_thread.is_alive(): + reporter._event_publisher_thread.join() + assert reporter._done.is_set() + assert len(lines) == 3, "expected 3 Job running messages" + + +# REFACTOR maybe we don't need this anymore +# @pytest.mark.integration_test +# def test_report_with_closed_received_exiting_gracefully(unused_tcp_port): +# mock_send_retry_time = 3 + +# def mock_send(msg): +# time.sleep(mock_send_retry_time) +# raise ClientConnectionClosedOK("Connection Closed") + +# host = "localhost" +# url = f"tcp://{host}:{unused_tcp_port}" +# reporter = Event(evaluator_url=url) +# 
fmstep1 = ForwardModelStep( +# {"name": "fmstep1", "stdout": "stdout", "stderr": "stderr"}, 0 +# ) +# lines = [] +# with mock_zmq_thread(unused_tcp_port, lines): +# reporter.report(Init([fmstep1], 1, 19, ens_id="ens_id", real_id=0)) +# reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=100, rss=10))) +# reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=200, rss=10))) + +# # sleep until both Running events have been received +# _wait_until( +# condition=lambda: len(lines) == 2, +# timeout=10, +# fail_msg="Should not take 10 seconds to send two events", +# ) + +# with patch( +# "_ert.forward_model_runner.client.Client.send", +# lambda x, y: mock_send(y), +# ): +# reporter.report( +# Running(fmstep1, ProcessTreeStatus(max_rss=300, rss=10)) +# ) +# # Make sure the publisher thread exits because it got +# # ClientConnectionClosedOK. If it hangs it could indicate that the +# # exception is not caught/handled correctly +# if reporter._event_publisher_thread.is_alive(): +# reporter._event_publisher_thread.join() + +# reporter.report(Running(fmstep1, ProcessTreeStatus(max_rss=400, rss=10))) +# reporter.report(Finish()) + +# # set _stop_timestamp was not set to None since the reporter finished on time +# assert reporter._timeout_timestamp is not None + +# # The Running(fmstep1, 300, 10) is popped from the queue, but never sent. 
+# # The following Running is added to queue along with the sentinel +# assert reporter._event_queue.qsize() == 2 +# # None of the messages after ClientConnectionClosedOK was raised, has been sent +# assert len(lines) == 2, "expected 2 Job running messages" diff --git a/tests/ert/unit_tests/forward_model_runner/test_job_dispatch.py b/tests/ert/unit_tests/forward_model_runner/test_job_dispatch.py index c9e77c25515..a39632d0800 100644 --- a/tests/ert/unit_tests/forward_model_runner/test_job_dispatch.py +++ b/tests/ert/unit_tests/forward_model_runner/test_job_dispatch.py @@ -316,7 +316,7 @@ def create_jobs_file_after_lock(): (tmp_path / JOBS_FILE).write_text(jobs_json) lock.release() - with mock_zmq_thread("localhost", unused_tcp_port, []): + with mock_zmq_thread(unused_tcp_port, []): thread = ErtThread(target=create_jobs_file_after_lock) thread.start() main(args=["script.py", str(tmp_path)]) @@ -345,7 +345,6 @@ def test_setup_reporters(is_interactive_run, ens_id): @pytest.mark.usefixtures("use_tmpdir") def test_job_dispatch_kills_itself_after_unsuccessful_job(unused_tcp_port): - host = "localhost" port = unused_tcp_port jobs_json = json.dumps( {"ens_id": "_id_", "dispatch_url": f"tcp://localhost:{port}"} @@ -363,7 +362,7 @@ def test_job_dispatch_kills_itself_after_unsuccessful_job(unused_tcp_port): ] mock_getpgid.return_value = 17 - with mock_zmq_thread(host, port, []): + with mock_zmq_thread(port, []): main(["script.py"]) mock_killpg.assert_called_with(17, signal.SIGKILL) diff --git a/tests/ert/utils.py b/tests/ert/utils.py index 815dcacac5a..b1cdf58bf42 100644 --- a/tests/ert/utils.py +++ b/tests/ert/utils.py @@ -2,6 +2,7 @@ import asyncio import contextlib +import queue import time from pathlib import Path from typing import TYPE_CHECKING @@ -63,13 +64,12 @@ def wait_until(func, interval=0.5, timeout=30): async def async_mock_zmq_server(messages, port, server_started): async def _handler(router_socket): while True: - dealer, __, *frames = await 
router_socket.recv_multipart() + dealer, __, frame = await router_socket.recv_multipart() await router_socket.send_multipart([dealer, b"", b"ACK"]) - for frame in frames: - raw_msg = frame.decode("utf-8") - messages.append(raw_msg) - if raw_msg == "DISCONNECT": - return + raw_msg = frame.decode("utf-8") + messages.append(raw_msg) + if raw_msg == "DISCONNECT": + return zmq_context = zmq.asyncio.Context() # type: ignore router_socket = zmq_context.socket(zmq.ROUTER) @@ -81,33 +81,36 @@ async def _handler(router_socket): @contextlib.contextmanager -def mock_zmq_thread(host, port, messages): +def mock_zmq_thread(port, messages, signal_queue=None): loop = None handler_task = None - def mock_zmq_server(messages, port): + def mock_zmq_server(messages, port, signal_queue=None): nonlocal loop, handler_task loop = asyncio.new_event_loop() async def _handler(router_socket): - nonlocal messages + nonlocal messages, signal_queue + signal_value = 0 while True: try: - dealer, __, *frames = await router_socket.recv_multipart() - await router_socket.send_multipart([dealer, b"", b"ACK"]) - for frame in frames: - raw_msg = frame.decode("utf-8") - print(f"{raw_msg} from {dealer}") - if raw_msg not in ["CONNECT", "DISCONNECT"]: - messages.append(raw_msg) + dealer, __, frame = await router_socket.recv_multipart() + if signal_queue: + with contextlib.suppress(queue.Empty): + signal_value = signal_queue.get(timeout=0.1) + print(f"{dealer=} {frame=} {signal_value=}") + if frame in [b"CONNECT", b"DISCONNECT"] or signal_value != 1: + await router_socket.send_multipart([dealer, b"", b"ACK"]) + if frame not in [b"CONNECT", b"DISCONNECT"]: + messages.append(frame.decode("utf-8")) except asyncio.CancelledError: break async def _run_server(): + nonlocal handler_task zmq_context = zmq.asyncio.Context() # type: ignore router_socket = zmq_context.socket(zmq.ROUTER) router_socket.bind(f"tcp://*:{port}") - nonlocal handler_task handler_task = asyncio.create_task(_handler(router_socket)) await 
handler_task router_socket.close() @@ -116,17 +119,12 @@ async def _run_server(): loop.close() mock_zmq_thread = ErtThread( - target=lambda: mock_zmq_server(messages, port), + target=lambda: mock_zmq_server(messages, port, signal_queue), ) mock_zmq_thread.start() try: yield - # Make sure to join the thread even if an exception occurs finally: - # url = f"tcp://{host}:{port}" - # with Client(url) as client: - # client.send("stop") - # # Cancel the handler task explicitly print(f"these are the final {messages=}") if handler_task and not handler_task.done(): loop.call_soon_threadsafe(handler_task.cancel)