From ec253eae64f725810b30ce36358533fbaccd308a Mon Sep 17 00:00:00 2001 From: xjules Date: Mon, 16 Dec 2024 13:34:34 +0100 Subject: [PATCH] Remove synced _send from Client --- src/_ert/forward_model_runner/client.py | 20 ++---------- .../forward_model_runner/reporting/event.py | 2 +- src/ert/ensemble_evaluator/_ensemble.py | 2 +- src/ert/ensemble_evaluator/monitor.py | 4 +-- .../test_ensemble_client.py | 4 +-- .../test_ensemble_evaluator.py | 32 +++++++++---------- 6 files changed, 25 insertions(+), 39 deletions(-) diff --git a/src/_ert/forward_model_runner/client.py b/src/_ert/forward_model_runner/client.py index a9db9c7c3e3..3869f6bb89a 100644 --- a/src/_ert/forward_model_runner/client.py +++ b/src/_ert/forward_model_runner/client.py @@ -8,8 +8,6 @@ import zmq import zmq.asyncio -from _ert.async_utils import new_event_loop - logger = logging.getLogger(__name__) @@ -31,14 +29,6 @@ class Client: DEFAULT_ACK_TIMEOUT = 5 _receiver_task: asyncio.Task[None] | None - def __enter__(self) -> Self: - self.loop.run_until_complete(self.__aenter__()) - return self - - def __exit__(self, exc_type: Any, exc_value: Any, exc_traceback: Any) -> None: - self.loop.run_until_complete(self.__aexit__(exc_type, exc_value, exc_traceback)) - self.loop.close() - async def __aenter__(self) -> Self: await self.connect() return self @@ -47,7 +37,7 @@ async def __aexit__( self, exc_type: Any, exc_value: Any, exc_traceback: Any ) -> None: try: - await self._send(DISCONNECT_MSG) + await self.send(DISCONNECT_MSG) except ClientConnectionError: logger.error("No ack for dealer disconnection. Connection is down!") finally: @@ -89,7 +79,6 @@ def __init__( self.socket.curve_publickey = client_public self.socket.curve_serverkey = token.encode("utf-8") - self.loop = new_event_loop() self._receiver_task = None async def connect(self) -> None: @@ -97,15 +86,12 @@ async def connect(self) -> None: await self._term_receiver_task() self._receiver_task = asyncio.create_task(self._receiver()) try: - await self._send(CONNECT_MSG, retries=1) + await self.send(CONNECT_MSG, retries=1) except ClientConnectionError: await self._term_receiver_task() self.term() raise - def send(self, message: str, retries: int | None = None) -> None: - self.loop.run_until_complete(self._send(message, retries)) - async def process_message(self, msg: str) -> None: raise NotImplementedError("Only monitor can receive messages!") @@ -124,7 +110,7 @@ async def _receiver(self) -> None: await asyncio.sleep(1) self.socket.connect(self.url) - async def _send(self, message: str, retries: int | None = None) -> None: + async def send(self, message: str, retries: int | None = None) -> None: self._ack_event.clear() backoff = 1 diff --git a/src/_ert/forward_model_runner/reporting/event.py b/src/_ert/forward_model_runner/reporting/event.py index 8ca50302131..5e804246a0c 100644 --- a/src/_ert/forward_model_runner/reporting/event.py +++ b/src/_ert/forward_model_runner/reporting/event.py @@ -110,7 +110,7 @@ async def publisher(): > self._finished_event_timeout ): break - await client._send(event_to_json(event), self._max_retries) + await client.send(event_to_json(event), self._max_retries) event = None except asyncio.CancelledError: return diff --git a/src/ert/ensemble_evaluator/_ensemble.py b/src/ert/ensemble_evaluator/_ensemble.py index 10087770623..e1577287975 100644 --- a/src/ert/ensemble_evaluator/_ensemble.py +++ b/src/ert/ensemble_evaluator/_ensemble.py @@ -194,7 +194,7 @@ async def send_event( retries: int = 10, ) -> None: async with Client(url, token) as client: - await client._send(event_to_json(event), retries) + await client.send(event_to_json(event), retries) def generate_event_creator(self) -> Callable[[Id.ENSEMBLE_TYPES], Event]: def event_builder(status: str) -> Event: diff --git a/src/ert/ensemble_evaluator/monitor.py b/src/ert/ensemble_evaluator/monitor.py index 90e8fc274d8..d55a50b9661 100644 --- a/src/ert/ensemble_evaluator/monitor.py +++ b/src/ert/ensemble_evaluator/monitor.py @@ -49,7 +49,7 @@ async def signal_cancel(self) -> None: logger.debug(f"monitor-{self._id} asking server to cancel...") cancel_event = EEUserCancel(monitor=self._id) - await self._send(event_to_json(cancel_event)) + await self.send(event_to_json(cancel_event)) logger.debug(f"monitor-{self._id} asked server to cancel") async def signal_done(self) -> None: @@ -57,7 +57,7 @@ async def signal_done(self) -> None: logger.debug(f"monitor-{self._id} informing server monitor is done...") done_event = EEUserDone(monitor=self._id) - await self._send(event_to_json(done_event)) + await self.send(event_to_json(done_event)) logger.debug(f"monitor-{self._id} informed server monitor is done") async def track( diff --git a/tests/ert/unit_tests/ensemble_evaluator/test_ensemble_client.py b/tests/ert/unit_tests/ensemble_evaluator/test_ensemble_client.py index 3a87cdac364..75501453140 100644 --- a/tests/ert/unit_tests/ensemble_evaluator/test_ensemble_client.py +++ b/tests/ert/unit_tests/ensemble_evaluator/test_ensemble_client.py @@ -21,7 +21,7 @@ async def test_successful_sending(unused_tcp_port): messages_c1 = ["test_1", "test_2", "test_3"] async with MockZMQServer(unused_tcp_port) as mock_server, Client(url) as c1: for message in messages_c1: - await c1._send(message) + await c1.send(message) for msg in messages_c1: assert msg in mock_server.messages @@ -38,7 +38,7 @@ async def test_retry(unused_tcp_port): ): for message in messages_c1: try: - await c1._send(message, retries=1) + await c1.send(message, retries=1) except ClientConnectionError: client_connection_error_set = True mock_server.signal(0) diff --git a/tests/ert/unit_tests/ensemble_evaluator/test_ensemble_evaluator.py b/tests/ert/unit_tests/ensemble_evaluator/test_ensemble_evaluator.py index 05bad359c40..4201b9468fd 100644 --- a/tests/ert/unit_tests/ensemble_evaluator/test_ensemble_evaluator.py +++ b/tests/ert/unit_tests/ensemble_evaluator/test_ensemble_evaluator.py @@ -174,7 +174,7 @@ async def test_restarted_jobs_do_not_have_error_msgs(evaluator_to_use): fm_step="0", current_memory_usage=1000, ) - await dispatch._send(event_to_json(event)) + await dispatch.send(event_to_json(event)) event = ForwardModelStepFailure( ensemble=evaluator.ensemble.id_, @@ -182,7 +182,7 @@ async def test_restarted_jobs_do_not_have_error_msgs(evaluator_to_use): fm_step="0", error_msg="error", ) - await dispatch._send(event_to_json(event)) + await dispatch.send(event_to_json(event)) def is_completed_snapshot(snapshot: EnsembleSnapshot) -> bool: try: @@ -211,7 +211,7 @@ def is_completed_snapshot(snapshot: EnsembleSnapshot) -> bool: fm_step="0", current_memory_usage=1000, ) - await dispatch._send(event_to_json(event)) + await dispatch.send(event_to_json(event)) # reconnect new monitor async with Monitor(config_info) as new_monitor: @@ -262,7 +262,7 @@ async def test_new_monitor_can_pick_up_where_we_left_off(evaluator_to_use): fm_step="0", current_memory_usage=1000, ) - await dispatch1._send(event_to_json(event)) + await dispatch1.send(event_to_json(event)) # second dispatch endpoint client informs that forward model 0 is running event = ForwardModelStepRunning( ensemble=evaluator.ensemble.id_, @@ -270,7 +270,7 @@ async def test_new_monitor_can_pick_up_where_we_left_off(evaluator_to_use): fm_step="0", current_memory_usage=1000, ) - await dispatch2._send(event_to_json(event)) + await dispatch2.send(event_to_json(event)) # second dispatch endpoint client informs that forward model 1 is running event = ForwardModelStepRunning( ensemble=evaluator.ensemble.id_, @@ -278,7 +278,7 @@ async def test_new_monitor_can_pick_up_where_we_left_off(evaluator_to_use): fm_step="1", current_memory_usage=1000, ) - await dispatch2._send(event_to_json(event)) + await dispatch2.send(event_to_json(event)) final_snapshot = EnsembleSnapshot() @@ -319,12 +319,12 @@ def check_if_all_fm_running(snapshot: EnsembleSnapshot) -> bool: fm_step="0", current_memory_usage=1000, ) - await dispatch2._send(event_to_json(event)) + await dispatch2.send(event_to_json(event)) # second dispatch endpoint client informs that job 1 is failed event = ForwardModelStepFailure( ensemble=evaluator.ensemble.id_, real="1", fm_step="1", error_msg="error" ) - await dispatch2._send(event_to_json(event)) + await dispatch2.send(event_to_json(event)) def check_if_final_snapshot_is_complete(final_snapshot: EnsembleSnapshot) -> bool: try: @@ -391,7 +391,7 @@ async def test_dispatch_endpoint_clients_can_connect_and_monitor_can_shut_down_e fm_step="0", current_memory_usage=1000, ) - await dispatch1._send(event_to_json(event)) + await dispatch1.send(event_to_json(event)) # second dispatch endpoint client informs that real 1 fm 0 is running event = ForwardModelStepRunning( ensemble=evaluator.ensemble.id_, @@ -399,7 +399,7 @@ async def test_dispatch_endpoint_clients_can_connect_and_monitor_can_shut_down_e fm_step="0", current_memory_usage=1000, ) - await dispatch2._send(event_to_json(event)) + await dispatch2.send(event_to_json(event)) # second dispatch endpoint client informs that real 1 fm 0 is done event = ForwardModelStepSuccess( ensemble=evaluator.ensemble.id_, @@ -407,7 +407,7 @@ async def test_dispatch_endpoint_clients_can_connect_and_monitor_can_shut_down_e fm_step="0", current_memory_usage=1000, ) - await dispatch2._send(event_to_json(event)) + await dispatch2.send(event_to_json(event)) # second dispatch endpoint client informs that real 1 fm 1 is failed event = ForwardModelStepFailure( ensemble=evaluator.ensemble.id_, @@ -415,7 +415,7 @@ async def test_dispatch_endpoint_clients_can_connect_and_monitor_can_shut_down_e fm_step="1", error_msg="error", ) - await dispatch2._send(event_to_json(event)) + await dispatch2.send(event_to_json(event)) event = await anext(events) snapshot = EnsembleSnapshot.from_nested_dict(event.snapshot) @@ -478,17 +478,17 @@ async def test_ensure_multi_level_events_in_order(evaluator_to_use): assert type(snapshot_event) is EESnapshot async with Client(url, token=token) as dispatch: event = EnsembleStarted(ensemble=evaluator.ensemble.id_) - await dispatch._send(event_to_json(event)) + await dispatch.send(event_to_json(event)) event = RealizationSuccess( ensemble=evaluator.ensemble.id_, real="0", queue_event_type="" ) - await dispatch._send(event_to_json(event)) + await dispatch.send(event_to_json(event)) event = RealizationSuccess( ensemble=evaluator.ensemble.id_, real="1", queue_event_type="" ) - await dispatch._send(event_to_json(event)) + await dispatch.send(event_to_json(event)) event = EnsembleSucceeded(ensemble=evaluator.ensemble.id_) - await dispatch._send(event_to_json(event)) + await dispatch.send(event_to_json(event)) await monitor.signal_done()