Skip to content

Commit

Permalink
Major rework
Browse files (browse the repository at this point in the history)
  • Loading branch information
xjules committed Dec 8, 2024
1 parent 786b4c4 commit 9cce793
Show file tree
Hide file tree
Showing 10 changed files with 182 additions and 187 deletions.
34 changes: 16 additions & 18 deletions src/_ert/forward_model_runner/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,8 @@ class ClientConnectionClosedOK(Exception):


class Client:
DEFAULT_MAX_RETRIES = 10
DEFAULT_TIMEOUT_MULTIPLIER = 5
CONNECTION_TIMEOUT = 60
DEFAULT_MAX_RETRIES = 5
DEFAULT_ACK_TIMEOUT = 5
_receiver_task: Optional[asyncio.Task[None]]

def __enter__(self) -> Self:
Expand Down Expand Up @@ -63,10 +62,10 @@ def __init__(
url: str,
token: Optional[str] = None,
cert: Optional[Union[str, bytes]] = None,
connection_timeout: float = 5.0,
dealer_name: Optional[str] = None,
ack_timeout: Optional[float] = None,
) -> None:
self._connection_timeout = connection_timeout
self._ack_timeout = ack_timeout or self.DEFAULT_ACK_TIMEOUT
self.url = url
self.token = token

Expand All @@ -80,7 +79,7 @@ def __init__(
else:
self.dealer_id = dealer_name
self.socket.setsockopt_string(zmq.IDENTITY, self.dealer_id)
print(f"Created: {self.dealer_id=} {token=} {self._connection_timeout=}")
print(f"Created: {self.dealer_id=} {token=} {self._ack_timeout=}")
if token is not None:
client_public, client_secret = zmq.curve_keypair()
self.socket.curve_secretkey = client_secret
Expand All @@ -95,14 +94,14 @@ async def connect(self) -> None:
await self._term_receiver_task()
self._receiver_task = asyncio.create_task(self._receiver())
try:
await self._send("CONNECT", max_retries=1)
await self._send("CONNECT", retries=1)
except ClientConnectionError:
await self._term_receiver_task()
self.term()
raise

def send(self, message: str, max_retries: int = DEFAULT_MAX_RETRIES) -> None:
self.loop.run_until_complete(self._send(message, max_retries))
def send(self, message: str, retries: Optional[int] = None) -> None:
self.loop.run_until_complete(self._send(message, retries))

async def process_message(self, msg: str) -> None:
pass
Expand All @@ -122,17 +121,17 @@ async def _receiver(self) -> None:
await asyncio.sleep(1)
self.socket.connect(self.url)

async def _send(self, message: str, max_retries: int = DEFAULT_MAX_RETRIES) -> None:
async def _send(self, message: str, retries: Optional[int] = None) -> None:
self._ack_event.clear()

backoff = 1

while max_retries > 0:
retries = retries or self.DEFAULT_MAX_RETRIES
while retries > 0:
try:
await self.socket.send_multipart([b"", message.encode("utf-8")])
try:
await asyncio.wait_for(
self._ack_event.wait(), timeout=self._connection_timeout
self._ack_event.wait(), timeout=self._ack_timeout
)
return
except asyncio.TimeoutError:
Expand All @@ -149,12 +148,11 @@ async def _send(self, message: str, max_retries: int = DEFAULT_MAX_RETRIES) -> N
self.term()
raise

max_retries -= 1
if max_retries > 0:
logger.info(f"Retrying... ({max_retries} attempts left)")
retries -= 1
if retries > 0:
logger.info(f"Retrying... ({retries} attempts left)")
await asyncio.sleep(backoff)
backoff = min(backoff * 2, 10) # Exponential backoff

raise ClientConnectionError(
f"{self.dealer_id} Failed to send {message=} after retries."
f"{self.dealer_id} Failed to send {message=} after {retries=}"
)
38 changes: 31 additions & 7 deletions src/_ert/forward_model_runner/reporting/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import asyncio
import logging
import queue
import time
from pathlib import Path
from typing import Final, Union

Expand All @@ -27,7 +28,7 @@
Start,
)
from _ert.forward_model_runner.reporting.statemachine import StateMachine
from _ert.threading import ErtThread
from _ert.threading import ErtThread, threading

logger = logging.getLogger(__name__)

Expand All @@ -54,7 +55,14 @@ class Event(Reporter):

_sentinel: Final = EventSentinel()

def __init__(self, evaluator_url, token=None, cert_path=None):
def __init__(
self,
evaluator_url,
token=None,
cert_path=None,
ack_timeout=None,
max_retries=None,
):
self._evaluator_url = evaluator_url
self._token = token
if cert_path is not None:
Expand All @@ -73,10 +81,13 @@ def __init__(self, evaluator_url, token=None, cert_path=None):
self._real_id = None
self._event_queue: queue.Queue[events.Event | EventSentinel] = queue.Queue()
self._event_publisher_thread = ErtThread(target=self._event_publisher)
self._done = False
self._done = threading.Event()
self._ack_timeout = ack_timeout
self._max_retries = max_retries

def stop(self):
self._event_queue.put(Event._sentinel)
self._done.set()
if self._event_publisher_thread.is_alive():
self._event_publisher_thread.join()

Expand All @@ -86,19 +97,31 @@ async def publisher():
url=self._evaluator_url,
token=self._token,
cert=self._cert,
ack_timeout=self._ack_timeout,
) as client:
event = None
start_time = None
while True:
try:
event = self._event_queue.get()
if event is self._sentinel:
if self._done.is_set() and start_time is None:
start_time = time.time()
if event is None:
event = self._event_queue.get()
if event is self._sentinel:
break
if start_time and (time.time() - start_time) > 2:
break
await client._send(event_to_json(event))
await client._send(event_to_json(event), self._max_retries)
event = None
except asyncio.CancelledError:
return
except ClientConnectionError as exc:
logger.error(f"Failed to send event: {exc}")

asyncio.run(publisher())
try:
asyncio.run(publisher())
except ClientConnectionError as exc:
raise ClientConnectionError("Couldn't connect to evaluator") from exc

def report(self, msg):
self._statemachine.transition(msg)
Expand Down Expand Up @@ -162,6 +185,7 @@ def _job_handler(self, msg: Union[Start, Running, Exited]):

def _finished_handler(self, _):
self._event_queue.put(Event._sentinel)
self._done.set()
if self._event_publisher_thread.is_alive():
self._event_publisher_thread.join()

Expand Down
2 changes: 1 addition & 1 deletion src/ert/ensemble_evaluator/_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ async def send_event(
retries: int = 10,
) -> None:
async with Client(url, token, cert) as client:
await client._send(event_to_json(event), max_retries=retries)
await client._send(event_to_json(event), retries)

def generate_event_creator(self) -> Callable[[Id.ENSEMBLE_TYPES], Event]:
def event_builder(status: str) -> Event:
Expand Down
1 change: 0 additions & 1 deletion src/ert/ensemble_evaluator/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,6 @@ async def listen_for_messages(self) -> None:
if sender.startswith("client"):
await self.handle_client(dealer, frame)
elif sender.startswith("dispatch"):
# await self._router_socket.send_multipart([dealer, b"", b"ACK"])
await self.handle_dispatch(dealer, frame)
else:
logger.info(f"Connection attempt to unknown sender: {sender}.")
Expand Down
5 changes: 2 additions & 3 deletions tests/ert/ui_tests/cli/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -952,9 +952,8 @@ def test_tracking_missing_ecl(monkeypatch, tmp_path, caplog):
def test_that_connection_errors_do_not_effect_final_result(
monkeypatch: pytest.MonkeyPatch,
):
monkeypatch.setattr(Client, "DEFAULT_MAX_RETRIES", 0)
monkeypatch.setattr(Client, "DEFAULT_TIMEOUT_MULTIPLIER", 0)
monkeypatch.setattr(Client, "CONNECTION_TIMEOUT", 1)
monkeypatch.setattr(Client, "DEFAULT_MAX_RETRIES", 1)
monkeypatch.setattr(Client, "DEFAULT_ACK_TIMEOUT", 1)
monkeypatch.setattr(EnsembleEvaluator, "CLOSE_SERVER_TIMEOUT", 0.01)
monkeypatch.setattr(Job, "DEFAULT_CHECKSUM_TIMEOUT", 0)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def test_invalid_server():

with (
pytest.raises(ClientConnectionError),
Client(url, connection_timeout=1.0),
Client(url, ack_timeout=1.0),
):
pass

Expand All @@ -31,7 +31,8 @@ async def test_successful_sending(unused_tcp_port):
await server_started.wait()
messages_c1 = ["test_1", "test_2", "test_3"]
async with Client(url) as c1:
await c1._send(messages_c1)
for message in messages_c1:
await c1._send(message)

await server_task

Expand Down
2 changes: 1 addition & 1 deletion tests/ert/unit_tests/ensemble_evaluator/test_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ async def async_zmq_server(port, handler):
async def test_no_connection_established(make_ee_config):
ee_config = make_ee_config()
monitor = Monitor(ee_config.get_connection_info())
monitor._connection_timeout = 0.1
monitor._ack_timeout = 0.1
with pytest.raises(ClientConnectionError):
async with monitor:
pass
Expand Down
Loading

0 comments on commit 9cce793

Please sign in to comment.