Skip to content

Commit

Permalink
Remove synced _send from Client
Browse files Browse the repository at this point in the history
  • Loading branch information
xjules committed Dec 16, 2024
1 parent b407664 commit ec253ea
Show file tree
Hide file tree
Showing 6 changed files with 25 additions and 39 deletions.
20 changes: 3 additions & 17 deletions src/_ert/forward_model_runner/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
import zmq
import zmq.asyncio

from _ert.async_utils import new_event_loop

logger = logging.getLogger(__name__)


Expand All @@ -31,14 +29,6 @@ class Client:
DEFAULT_ACK_TIMEOUT = 5
_receiver_task: asyncio.Task[None] | None

def __enter__(self) -> Self:
self.loop.run_until_complete(self.__aenter__())
return self

def __exit__(self, exc_type: Any, exc_value: Any, exc_traceback: Any) -> None:
self.loop.run_until_complete(self.__aexit__(exc_type, exc_value, exc_traceback))
self.loop.close()

async def __aenter__(self) -> Self:
await self.connect()
return self
Expand All @@ -47,7 +37,7 @@ async def __aexit__(
self, exc_type: Any, exc_value: Any, exc_traceback: Any
) -> None:
try:
await self._send(DISCONNECT_MSG)
await self.send(DISCONNECT_MSG)
except ClientConnectionError:
logger.error("No ack for dealer disconnection. Connection is down!")
finally:
Expand Down Expand Up @@ -89,23 +79,19 @@ def __init__(
self.socket.curve_publickey = client_public
self.socket.curve_serverkey = token.encode("utf-8")

self.loop = new_event_loop()
self._receiver_task = None

async def connect(self) -> None:
self.socket.connect(self.url)
await self._term_receiver_task()
self._receiver_task = asyncio.create_task(self._receiver())
try:
await self._send(CONNECT_MSG, retries=1)
await self.send(CONNECT_MSG, retries=1)
except ClientConnectionError:
await self._term_receiver_task()
self.term()
raise

def send(self, message: str, retries: int | None = None) -> None:
self.loop.run_until_complete(self._send(message, retries))

async def process_message(self, msg: str) -> None:
raise NotImplementedError("Only monitor can receive messages!")

Expand All @@ -124,7 +110,7 @@ async def _receiver(self) -> None:
await asyncio.sleep(1)
self.socket.connect(self.url)

async def _send(self, message: str, retries: int | None = None) -> None:
async def send(self, message: str, retries: int | None = None) -> None:
self._ack_event.clear()

backoff = 1
Expand Down
2 changes: 1 addition & 1 deletion src/_ert/forward_model_runner/reporting/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ async def publisher():
> self._finished_event_timeout
):
break
await client._send(event_to_json(event), self._max_retries)
await client.send(event_to_json(event), self._max_retries)
event = None
except asyncio.CancelledError:
return
Expand Down
2 changes: 1 addition & 1 deletion src/ert/ensemble_evaluator/_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ async def send_event(
retries: int = 10,
) -> None:
async with Client(url, token) as client:
await client._send(event_to_json(event), retries)
await client.send(event_to_json(event), retries)

def generate_event_creator(self) -> Callable[[Id.ENSEMBLE_TYPES], Event]:
def event_builder(status: str) -> Event:
Expand Down
4 changes: 2 additions & 2 deletions src/ert/ensemble_evaluator/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,15 +49,15 @@ async def signal_cancel(self) -> None:
logger.debug(f"monitor-{self._id} asking server to cancel...")

cancel_event = EEUserCancel(monitor=self._id)
await self._send(event_to_json(cancel_event))
await self.send(event_to_json(cancel_event))
logger.debug(f"monitor-{self._id} asked server to cancel")

async def signal_done(self) -> None:
await self._event_queue.put(Monitor._sentinel)
logger.debug(f"monitor-{self._id} informing server monitor is done...")

done_event = EEUserDone(monitor=self._id)
await self._send(event_to_json(done_event))
await self.send(event_to_json(done_event))
logger.debug(f"monitor-{self._id} informed server monitor is done")

async def track(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ async def test_successful_sending(unused_tcp_port):
messages_c1 = ["test_1", "test_2", "test_3"]
async with MockZMQServer(unused_tcp_port) as mock_server, Client(url) as c1:
for message in messages_c1:
await c1._send(message)
await c1.send(message)

for msg in messages_c1:
assert msg in mock_server.messages
Expand All @@ -38,7 +38,7 @@ async def test_retry(unused_tcp_port):
):
for message in messages_c1:
try:
await c1._send(message, retries=1)
await c1.send(message, retries=1)
except ClientConnectionError:
client_connection_error_set = True
mock_server.signal(0)
Expand Down
32 changes: 16 additions & 16 deletions tests/ert/unit_tests/ensemble_evaluator/test_ensemble_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,15 +174,15 @@ async def test_restarted_jobs_do_not_have_error_msgs(evaluator_to_use):
fm_step="0",
current_memory_usage=1000,
)
await dispatch._send(event_to_json(event))
await dispatch.send(event_to_json(event))

event = ForwardModelStepFailure(
ensemble=evaluator.ensemble.id_,
real="0",
fm_step="0",
error_msg="error",
)
await dispatch._send(event_to_json(event))
await dispatch.send(event_to_json(event))

def is_completed_snapshot(snapshot: EnsembleSnapshot) -> bool:
try:
Expand Down Expand Up @@ -211,7 +211,7 @@ def is_completed_snapshot(snapshot: EnsembleSnapshot) -> bool:
fm_step="0",
current_memory_usage=1000,
)
await dispatch._send(event_to_json(event))
await dispatch.send(event_to_json(event))

# reconnect new monitor
async with Monitor(config_info) as new_monitor:
Expand Down Expand Up @@ -262,23 +262,23 @@ async def test_new_monitor_can_pick_up_where_we_left_off(evaluator_to_use):
fm_step="0",
current_memory_usage=1000,
)
await dispatch1._send(event_to_json(event))
await dispatch1.send(event_to_json(event))
# second dispatch endpoint client informs that forward model 0 is running
event = ForwardModelStepRunning(
ensemble=evaluator.ensemble.id_,
real="1",
fm_step="0",
current_memory_usage=1000,
)
await dispatch2._send(event_to_json(event))
await dispatch2.send(event_to_json(event))
# second dispatch endpoint client informs that forward model 1 is running
event = ForwardModelStepRunning(
ensemble=evaluator.ensemble.id_,
real="1",
fm_step="1",
current_memory_usage=1000,
)
await dispatch2._send(event_to_json(event))
await dispatch2.send(event_to_json(event))

final_snapshot = EnsembleSnapshot()

Expand Down Expand Up @@ -319,12 +319,12 @@ def check_if_all_fm_running(snapshot: EnsembleSnapshot) -> bool:
fm_step="0",
current_memory_usage=1000,
)
await dispatch2._send(event_to_json(event))
await dispatch2.send(event_to_json(event))
# second dispatch endpoint client informs that job 1 is failed
event = ForwardModelStepFailure(
ensemble=evaluator.ensemble.id_, real="1", fm_step="1", error_msg="error"
)
await dispatch2._send(event_to_json(event))
await dispatch2.send(event_to_json(event))

def check_if_final_snapshot_is_complete(final_snapshot: EnsembleSnapshot) -> bool:
try:
Expand Down Expand Up @@ -391,31 +391,31 @@ async def test_dispatch_endpoint_clients_can_connect_and_monitor_can_shut_down_e
fm_step="0",
current_memory_usage=1000,
)
await dispatch1._send(event_to_json(event))
await dispatch1.send(event_to_json(event))
# second dispatch endpoint client informs that real 1 fm 0 is running
event = ForwardModelStepRunning(
ensemble=evaluator.ensemble.id_,
real="1",
fm_step="0",
current_memory_usage=1000,
)
await dispatch2._send(event_to_json(event))
await dispatch2.send(event_to_json(event))
# second dispatch endpoint client informs that real 1 fm 0 is done
event = ForwardModelStepSuccess(
ensemble=evaluator.ensemble.id_,
real="1",
fm_step="0",
current_memory_usage=1000,
)
await dispatch2._send(event_to_json(event))
await dispatch2.send(event_to_json(event))
# second dispatch endpoint client informs that real 1 fm 1 is failed
event = ForwardModelStepFailure(
ensemble=evaluator.ensemble.id_,
real="1",
fm_step="1",
error_msg="error",
)
await dispatch2._send(event_to_json(event))
await dispatch2.send(event_to_json(event))

event = await anext(events)
snapshot = EnsembleSnapshot.from_nested_dict(event.snapshot)
Expand Down Expand Up @@ -478,17 +478,17 @@ async def test_ensure_multi_level_events_in_order(evaluator_to_use):
assert type(snapshot_event) is EESnapshot
async with Client(url, token=token) as dispatch:
event = EnsembleStarted(ensemble=evaluator.ensemble.id_)
await dispatch._send(event_to_json(event))
await dispatch.send(event_to_json(event))
event = RealizationSuccess(
ensemble=evaluator.ensemble.id_, real="0", queue_event_type=""
)
await dispatch._send(event_to_json(event))
await dispatch.send(event_to_json(event))
event = RealizationSuccess(
ensemble=evaluator.ensemble.id_, real="1", queue_event_type=""
)
await dispatch._send(event_to_json(event))
await dispatch.send(event_to_json(event))
event = EnsembleSucceeded(ensemble=evaluator.ensemble.id_)
await dispatch._send(event_to_json(event))
await dispatch.send(event_to_json(event))

await monitor.signal_done()

Expand Down

0 comments on commit ec253ea

Please sign in to comment.