diff --git a/tests/cli/test_remote.py b/tests/cli/test_remote.py index a66b556d..43f551a0 100644 --- a/tests/cli/test_remote.py +++ b/tests/cli/test_remote.py @@ -6,7 +6,6 @@ import typing import unittest import unittest.mock -from pathlib import Path from typing import Callable from unittest.mock import Mock, create_autospec @@ -305,7 +304,12 @@ def test_display( r = Remote(host, connection=mock_connection) r.display(message) output = capsys.readouterr().out - assert output == f"({host}) $ {message}\n" + # NOTE: This way of testing is also resilient to Pytest's `-s` option being used, + # since color codes are added to the output. + assert output in ( + f"\x1b[1m\x1b[36m({host}) $ {message}\x1b(B\x1b[m\n", + f"({host}) $ {message}", + ) @disable_internet_access @@ -397,7 +401,9 @@ def test_extract( pty: bool, hide: bool, ): - """TODO: It's very hard to write this test in such a way where it isn't testing itself...""" + """TODO: It's very hard to write this test in such a way where it doesn't just test + itself... + """ test_command = "echo 'hello my name is $USER'" command_output = "hello my name is bob" @@ -411,7 +417,8 @@ def test_extract( _name="mock_runner", ) - # NOTE: The runner needs to write stuff to into the out_stream. This is a bit tricky. + # NOTE: The runner needs to write stuff to into the out_stream. This is a bit + # tricky. write_stuff_was_called = False def write_stuff( @@ -443,8 +450,8 @@ def write_stuff( ) mock_connection.run.return_value = mock_promise - # TODO: This makes the test pass, but it becomes pretty meaningless at this point. Both the - # Promise/Result, Runner, etc don't get used. I'm (@lebrice) not sure + # TODO: This makes the test pass, but it becomes pretty meaningless at this point. + # Both the Promise/Result, Runner, etc don't get used. I'm (@lebrice) not sure mock_connection.run.side_effect = write_stuff r = Remote(hostname=host, connection=mock_connection) @@ -467,7 +474,8 @@ def write_stuff( @disable_internet_access def test_get(mock_connection: Connection, host: str): - # TODO: Make this test smarter? or no need? (because we'd be testing fabric at that point?) + # TODO: Make this test smarter? or no need? (because we'd be testing fabric at that + # point?) r = Remote(host, connection=mock_connection) _result = r.get("foo", "bar") mock_connection.get.assert_called_once_with("foo", "bar") @@ -528,12 +536,19 @@ def test_ensure_allocation(remote: Remote): assert remote.ensure_allocation() == ({"node_name": remote.hostname}, None) +def some_transform(x: str) -> str: + return f"echo Hello && {x}" + + +def some_other_transform(x: str) -> str: + return f"echo 'this is printed after the command' && {x}" + + class TestSlurmRemote: @pytest.mark.parametrize("persist", [True, False]) def test_init_(self, mock_connection: Connection, persist: bool): alloc = ["--time=00:01:00"] - transforms = [lambda x: f"echo Hello && {x}"] - persist: bool = False + transforms = [some_transform] remote = SlurmRemote( mock_connection, alloc=alloc, transforms=transforms, persist=persist ) @@ -545,20 +560,22 @@ def test_init_(self, mock_connection: Connection, persist: bool): remote.srun_transform_persist if persist else remote.srun_transform, ] + @can_run_for_real def test_srun_transform(self, mock_connection: Connection): alloc = ["--time=00:01:00"] - transforms = [lambda x: f"echo Hello && {x}"] + transforms = [some_transform] persist: bool = False remote = SlurmRemote( mock_connection, alloc=alloc, transforms=transforms, persist=persist ) - # Transforms aren't used here. Seems a bit weird for this to be a public method then, no? + # Transforms aren't used here. Seems a bit weird for this to be a public method + # then, no? assert remote.srun_transform("bob") == "srun --time=00:01:00 bash -c bob" @pytest.mark.skip(reason="Seems a bit hard to test for what it's worth..") def test_srun_transform_persist(self, mock_connection: Connection): alloc = ["--time=00:01:00"] - transforms = [lambda x: f"echo Hello && {x}"] + transforms = [some_transform] persist: bool = False remote = SlurmRemote( mock_connection, alloc=alloc, transforms=transforms, persist=persist @@ -572,7 +589,7 @@ def test_srun_transform_persist(self, mock_connection: Connection): @pytest.mark.parametrize("persist", [True, False, None]) def test_with_transforms(self, mock_connection: Connection, persist: bool | None): alloc = ["--time=00:01:00"] - transforms = [lambda x: f"echo Hello && {x}"] + transforms = [some_transform] original_persist: bool = False remote = SlurmRemote( mock_connection, @@ -580,19 +597,26 @@ def test_with_transforms(self, mock_connection: Connection, persist: bool | None transforms=transforms, persist=original_persist, ) - - new_transforms = [lambda x: f"echo 'this is printed after the command' && {x}"] + new_transforms = [some_other_transform] transformed = remote.with_transforms(*new_transforms, persist=persist) # NOTE: Feels dumb to do this. Not sure what I should be doing otherwise. assert transformed.connection == remote.connection assert transformed.alloc == remote.alloc - assert transformed.transforms == [*remote.transforms[:-1], *transforms] - assert transformed._persist == original_persist if persist is None else persist + assert transformed.transforms == [ + some_transform, + some_other_transform, + ( + transformed.srun_transform_persist + if persist + else transformed.srun_transform + ), + ] + assert transformed._persist == (remote._persist if persist is None else persist) @pytest.mark.parametrize("persist", [True, False]) def test_persist(self, mock_connection: Connection, persist: bool): alloc = ["--time=00:01:00"] - transforms = [lambda x: f"echo Hello && {x}"] + transforms = [some_transform] remote = SlurmRemote( mock_connection, alloc=alloc, transforms=transforms, persist=persist ) @@ -601,17 +625,50 @@ def test_persist(self, mock_connection: Connection, persist: bool): # NOTE: Feels dumb to do this. Not sure what I should be doing otherwise. assert persisted.connection == remote.connection assert persisted.alloc == remote.alloc - assert persisted.transforms == [*remote.transforms[:-1]] + assert persisted.transforms == [ + some_transform, + persisted.srun_transform_persist, + ] assert persisted._persist is True + @disable_internet_access @pytest.mark.parametrize("persist", [True, False]) def test_ensure_allocation(self, mock_connection: Connection, persist: bool): alloc = ["--time=00:01:00"] - transforms = [lambda x: f"echo Hello && {x}"] + transforms = [some_transform] remote = SlurmRemote( mock_connection, alloc=alloc, transforms=transforms, persist=persist ) - persisted = remote.ensure_allocation() + write_stuff_was_called = False + command_output = "\n".join( + [ + "bob-002", + "Submitted batch job 1234", + ] + ) + + def write_stuff( + command: str, + asynchronous: bool, + hide: bool, + warn: bool, + pty: bool, + out_stream: QueueIO, + ): + nonlocal write_stuff_was_called + assert command == "echo @@@ $(hostname) @@@ && sleep 1000d" + # patterns={ + # "node_name": "@@@ ([^ ]+) @@@", + # "jobid": "Submitted batch job ([0-9]+)", + # }, + assert hide is True + write_stuff_was_called = True + out_stream.write(command_output) + return unittest.mock.DEFAULT + # return invoke.runners.Promise(mock_runner) + + mock_connection.run.side_effect = write_stuff + remote.ensure_allocation() raise NotImplementedError("TODO: Imporant and potentially complicated test")