Skip to content

Commit

Permalink
Add test for handle dispatcher and dispatcher messages in evaluator
Browse files Browse the repository at this point in the history
  • Loading branch information
xjules committed Dec 17, 2024
1 parent 9972325 commit 745d68b
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 17 deletions.
5 changes: 0 additions & 5 deletions src/ert/ensemble_evaluator/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,23 +277,18 @@ async def forward_checksum(self, event: Event) -> None:
async def _server(self) -> None:
zmq_context = zmq.asyncio.Context()
try:
print("INIT ZMQ ...")
# Create and configure the ROUTER socket
self._router_socket: zmq.asyncio.Socket = zmq_context.socket(zmq.ROUTER)
self._router_socket.setsockopt(zmq.LINGER, 0)
if self._config.server_public_key and self._config.server_secret_key:
self._router_socket.curve_secretkey = self._config.server_secret_key
self._router_socket.curve_publickey = self._config.server_public_key
self._router_socket.curve_server = True

# Attempt to bind the ROUTER socket
# self._router_socket.bind(f"tcp://*:{self._config.router_port}")
if self._config.router_port:
self._router_socket.bind(f"tcp://*:{self._config.router_port}")
else:
self._router_socket.bind(self._config.url)
self._server_started.set()
print(f"ROUTER listens on {self._config.url}")
except zmq.error.ZMQError as e:
logger.error(f"ZMQ error encountered {e} during evaluator initialization")
print(f"ZMQ error encountered {e} during evaluator initialization")
Expand Down
34 changes: 22 additions & 12 deletions tests/ert/unit_tests/ensemble_evaluator/test_ensemble_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest
from hypothesis import given
from hypothesis import strategies as st
from pydantic import ValidationError

from _ert.events import (
EESnapshot,
Expand All @@ -19,7 +20,7 @@
RealizationSuccess,
event_to_json,
)
from _ert.forward_model_runner.client import Client
from _ert.forward_model_runner.client import CONNECT_MSG, DISCONNECT_MSG, Client
from ert.ensemble_evaluator import (
EnsembleEvaluator,
EnsembleSnapshot,
Expand Down Expand Up @@ -66,18 +67,27 @@ async def mock_failure(message, *args, **kwargs):
await evaluator.run_and_get_successful_realizations()


# TODO refactor this test
# async def test_when_dispatch_is_given_invalid_event_the_socket_is_closed(
# make_ee_config,
# ):
# evaluator = EnsembleEvaluator(TestEnsemble(0, 2, 2, id_="0"), make_ee_config())
async def test_evaluator_raises_on_invalid_dispatch_event(
make_ee_config,
):
evaluator = EnsembleEvaluator(TestEnsemble(0, 2, 2, id_="0"), make_ee_config())

with pytest.raises(ValidationError):
await evaluator.handle_dispatch(b"dispatcher-1", b"This is not an event!!")


async def test_evaluator_handles_dispatchers_connected(
make_ee_config,
):
evaluator = EnsembleEvaluator(TestEnsemble(0, 2, 2, id_="0"), make_ee_config())

# socket = MagicMock(spec=WebSocketServerProtocol)
# socket.__aiter__.return_value = ["invalid_json"]
# await evaluator.handle_dispatch(socket)
# socket.close.assert_called_once_with(
# code=1011, reason="failed handling message 'invalid_json'"
# )
await evaluator.handle_dispatch(b"dispatcher-1", CONNECT_MSG.encode("utf-8"))
await evaluator.handle_dispatch(b"dispatcher-2", CONNECT_MSG.encode("utf-8"))
assert not evaluator._dispatchers_empty.is_set()
assert evaluator._dispatchers_connected == {b"dispatcher-1", b"dispatcher-2"}
await evaluator.handle_dispatch(b"dispatcher-1", DISCONNECT_MSG.encode("utf-8"))
await evaluator.handle_dispatch(b"dispatcher-2", DISCONNECT_MSG.encode("utf-8"))
assert evaluator._dispatchers_empty.is_set()


async def test_no_config_raises_valueerror_when_running():
Expand Down

0 comments on commit 745d68b

Please sign in to comment.