From db3bb4b7f034cd67cab6a517fbe15774ecc50f41 Mon Sep 17 00:00:00 2001 From: Lourens Veen Date: Sun, 17 Dec 2023 11:15:39 +0100 Subject: [PATCH] Fix incorrect assertions and improve Instance tests --- .../python/libmuscle/test/test_instance.py | 143 +++++++++++++----- 1 file changed, 101 insertions(+), 42 deletions(-) diff --git a/libmuscle/python/libmuscle/test/test_instance.py b/libmuscle/python/libmuscle/test/test_instance.py index 84e0973f..f500dd64 100644 --- a/libmuscle/python/libmuscle/test/test_instance.py +++ b/libmuscle/python/libmuscle/test/test_instance.py @@ -9,6 +9,7 @@ from libmuscle.communicator import Message from libmuscle.instance import Instance, InstanceFlags as IFlags from libmuscle.mpp_message import ClosePort +from libmuscle.port import Port from libmuscle.settings_manager import SettingsManager @@ -37,27 +38,69 @@ def sys_argv_instance() -> Generator[None, None, None]: @pytest.fixture -def instance(sys_argv_instance, tmp_path): +def overlay_settings(): + settings = Settings() + settings['test1'] = 12 + return settings + + +@pytest.fixture +def instance(sys_argv_instance, tmp_path, overlay_settings): + ports = { + Operator.F_INIT: ['in', 'not_connected'], + Operator.S: ['in_s', 'in_settings', 'not_connected_s'], + Operator.O_F: ['out', 'out_v']} + + def port_exists(name): + return [name for op, names in ports.items() if name in names] != [] + + def is_connected(name): + return 'not_connected' not in name + + def is_vector(name): + return name.endswith('_v') + + def get_port(name): + return Port( + name, + [op for op, names in ports.items() if name in names][0], + is_vector(name), is_connected(name), 2, + [13, 42, 10] if is_vector(name) else [13, 42]) + with patch('libmuscle.instance.MMPClient') as mmp_client, \ patch('libmuscle.instance.Communicator') as comm_type: communicator = MagicMock() - settings = Settings() - settings['test1'] = 12 - msg = Message(0.0, 1.0, 'message', settings) - communicator.receive_message.return_value = msg, 10.0 + communicator.port_exists = MagicMock(side_effect=port_exists) + communicator.is_connected = MagicMock(side_effect=is_connected) + communicator.is_vector = MagicMock(side_effect=is_vector) + communicator.get_port = MagicMock(side_effect=get_port) + + msg = Message(0.0, 1.0, 'message') + msg_with_settings = Message(0.0, 1.0, 'message', overlay_settings) + + def receive_message(name, slot, default): + if 'not_connected' in name: + return default, 10.0 + if 'settings' in name: + return msg_with_settings, 10.0 + return msg, 10.0 + + communicator.receive_message = MagicMock(side_effect=receive_message) + comm_type.return_value = communicator mmp_client_object = MagicMock() mmp_client_object.request_peers.return_value = (None, None, None) + checkpoint_info = (0.0, Checkpoints(), None, tmp_path) mmp_client_object.get_checkpoint_info.return_value = checkpoint_info + mmp_client.return_value = mmp_client_object - instance = Instance({ - Operator.F_INIT: ['in', 'not_connected'], - Operator.O_F: ['out']}) + instance = Instance(ports) instance._f_init_cache = dict() instance._f_init_cache[('in', None)] = msg + yield instance @@ -97,8 +140,6 @@ def test_create_instance( assert len(instance._settings_manager.overlay) == 0 mmp_client.assert_called_once_with( Reference('test_instance[13][42]'), 'localhost:9000') - assert mmp_client_object._register.called_with() - assert mmp_client_object._connect.called_with() comm_type.assert_called_with(Reference('test_instance'), [13, 42], ports, instance._profiler) assert instance._communicator == comm_type.return_value @@ -154,67 +195,73 @@ def test_get_setting(instance): def test_list_ports(instance): ports = instance.list_ports() - assert instance._communicator.list_ports.called_with() + instance._communicator.list_ports.assert_called_with() assert ports == instance._communicator.list_ports.return_value def test_is_vector_port(instance): - instance._communicator.get_port.return_value.is_vector = MagicMock( - return_value=True) - is_vector = instance.is_vector_port('out_port') - assert is_vector is True - assert instance._communicator.get_port.called_with('out_port') + assert instance.is_vector_port('out_v') + instance._communicator.get_port.assert_called_with('out_v') def test_send(instance, message): instance._trigger_manager._cpts_considered_until = 17.0 instance.send('out', message, 1) - assert instance._communicator.send_message.called_with( - 'out', message, 1, 17.0) + instance._communicator.send_message.assert_called_with('out', message, 1, 17.0) def test_send_invalid_port(instance, message): - instance._communicator.port_exists.return_value = False with pytest.raises(RuntimeError): instance.send('does_not_exist', message, 1) def test_receive(instance): - instance._communicator.get_port.return_value = MagicMock( - operator=Operator.F_INIT) + msg = instance.receive('in_s') + assert msg.timestamp == 0.0 + assert msg.next_timestamp == 1.0 + assert msg.data == 'message' + instance._communicator.receive_message.assert_called_with('in_s', None, None) + + +def test_receive_cached(instance): msg = instance.receive('in') assert msg.timestamp == 0.0 assert msg.next_timestamp == 1.0 - assert instance._communicator.receive_message.called_with( - 'in', None) assert msg.data == 'message' + instance._communicator.receive_message.assert_not_called() with pytest.raises(RuntimeError): instance.receive('in') def test_receive_default(instance): - instance._communicator.port_exists.return_value = True - port = instance._communicator.get_port.return_value - port.operator = Operator.F_INIT - port.is_connected.return_value = False - instance.receive('not_connected', 1, 'testing') - assert instance._communicator.receive_message.called_with( - 'not_connected', 1, 'testing') + default_msg = Message(1.0, 2.0, 'testing') + msg = instance.receive('not_connected_s', 1, default_msg) + instance._communicator.receive_message.assert_called_with( + 'not_connected_s', 1, default_msg) + assert msg == default_msg + + +def test_receive_default_cached(instance): + msg = instance.receive('not_connected', 1, Message(1.0, 2.0, 'testing')) + assert msg.timestamp == 1.0 + assert msg.next_timestamp == 2.0 + assert msg.data == 'testing' + instance._communicator.receive_message.assert_not_called() with pytest.raises(RuntimeError): instance.receive('not_connected', 1) def test_receive_invalid_port(instance): - instance._communicator.port_exists.return_value = False with pytest.raises(RuntimeError): instance.receive('does_not_exist', 1) -def test_receive_with_settings(instance): - msg = instance.receive_with_settings('in', 1) - assert (instance._communicator.receive_message - .called_with('in', 1)) +def test_receive_with_settings(instance, overlay_settings): + instance._settings_manager.overlay = overlay_settings + msg = instance.receive_with_settings('in_settings') + instance._communicator.receive_message.assert_called_with( + 'in_settings', None, None) assert msg.timestamp == 0.0 assert msg.next_timestamp == 1.0 assert msg.data == 'message' @@ -222,30 +269,42 @@ def test_receive_with_settings(instance): def test_receive_with_settings_default(instance): - instance.receive_with_settings('not_connected', 1, 'testing') - assert instance._communicator.receive_message.called_with( - 'not_connected', 1, 'testing') + settings = Settings() + settings['test1'] = 42 + default_msg = Message(1.0, 2.0, 'testing', settings) + msg = instance.receive_with_settings('not_connected_s', default=default_msg) + instance._communicator.receive_message.assert_called_with( + 'not_connected_s', None, msg) + assert msg.settings['test1'] == 42 def test_receive_parallel_universe(instance) -> None: instance._settings_manager.overlay['test2'] = 'test' with pytest.raises(RuntimeError): - instance.receive('in') + instance.receive('in_settings') def test_reuse_instance_receive_overlay(instance): instance._settings_manager.overlay = Settings() + test_base_settings = Settings() test_base_settings['test1'] = 24 test_base_settings['test2'] = [1.3, 2.0] + test_overlay = Settings() test_overlay['test2'] = 'abc' - recv = instance._communicator.receive_message + msg = Message(0.0, None, test_overlay, test_base_settings) + + recv = instance._communicator.receive_message + recv.reset_mock(side_effect=True) recv.return_value = msg, 0.0 + instance.reuse_instance() - assert instance._communicator.receive_message.called_with( - 'muscle_settings_in') + instance._communicator.receive_message.assert_called() + assert instance._communicator.receive_message.call_args[0][0] == ( + 'muscle_settings_in') + assert len(instance._settings_manager.overlay) == 2 assert instance._settings_manager.overlay['test1'] == 24 assert instance._settings_manager.overlay['test2'] == 'abc'