Skip to content

Commit

Permalink
Add some tests for SlurmRemote
Browse files Browse the repository at this point in the history
Signed-off-by: Fabrice Normandin <[email protected]>
  • Loading branch information
lebrice committed Oct 6, 2023
1 parent a2c4c36 commit d5b2061
Showing 1 changed file with 94 additions and 2 deletions.
96 changes: 94 additions & 2 deletions tests/cli/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,12 @@
from pytest_regressions.file_regression import FileRegressionFixture
from typing_extensions import ParamSpec

from milatools.cli.remote import QueueIO, Remote, get_first_node_name
from milatools.cli.remote import (
QueueIO,
Remote,
SlurmRemote,
get_first_node_name,
)

# TODO: Enable running the tests "for real" on the mila cluster using a flag?
# - This would require us to use "proper" commands e.g. 'echo OK' can't output "bobobo".
Expand Down Expand Up @@ -520,7 +525,94 @@ def remote(mock_connection: Connection, host: str) -> Remote:


def test_ensure_allocation(remote: Remote):
assert remote.ensure_allocation() == {"node_name": remote.hostname}, None
assert remote.ensure_allocation() == ({"node_name": remote.hostname}, None)


class TestSlurmRemote:
@pytest.mark.parametrize("persist", [True, False])
def test_init_(self, mock_connection: Connection, persist: bool):
alloc = ["--time=00:01:00"]
transforms = [lambda x: f"echo Hello && {x}"]
persist: bool = False
remote = SlurmRemote(
mock_connection, alloc=alloc, transforms=transforms, persist=persist
)
# TODO: This kind of test feels a bit dumb.
assert remote.connection is mock_connection
assert remote._persist == persist
assert remote.transforms == [
*transforms,
remote.srun_transform_persist if persist else remote.srun_transform,
]

def test_srun_transform(self, mock_connection: Connection):
alloc = ["--time=00:01:00"]
transforms = [lambda x: f"echo Hello && {x}"]
persist: bool = False
remote = SlurmRemote(
mock_connection, alloc=alloc, transforms=transforms, persist=persist
)
# Transforms aren't used here. Seems a bit weird for this to be a public method then, no?
assert remote.srun_transform("bob") == "srun --time=00:01:00 bash -c bob"

@pytest.mark.skip(reason="Seems a bit hard to test for what it's worth..")
def test_srun_transform_persist(self, mock_connection: Connection):
alloc = ["--time=00:01:00"]
transforms = [lambda x: f"echo Hello && {x}"]
persist: bool = False
remote = SlurmRemote(
mock_connection, alloc=alloc, transforms=transforms, persist=persist
)
output_file = "<some_file>"
assert (
remote.srun_transform_persist("bob")
== f"bob; touch {output_file}; tail -n +1 -f {output_file}"
)

@pytest.mark.parametrize("persist", [True, False, None])
def test_with_transforms(self, mock_connection: Connection, persist: bool | None):
alloc = ["--time=00:01:00"]
transforms = [lambda x: f"echo Hello && {x}"]
original_persist: bool = False
remote = SlurmRemote(
mock_connection,
alloc=alloc,
transforms=transforms,
persist=original_persist,
)

new_transforms = [lambda x: f"echo 'this is printed after the command' && {x}"]
transformed = remote.with_transforms(*new_transforms, persist=persist)
# NOTE: Feels dumb to do this. Not sure what I should be doing otherwise.
assert transformed.connection == remote.connection
assert transformed.alloc == remote.alloc
assert transformed.transforms == [*remote.transforms[:-1], *transforms]
assert transformed._persist == original_persist if persist is None else persist

@pytest.mark.parametrize("persist", [True, False])
def test_persist(self, mock_connection: Connection, persist: bool):
alloc = ["--time=00:01:00"]
transforms = [lambda x: f"echo Hello && {x}"]
remote = SlurmRemote(
mock_connection, alloc=alloc, transforms=transforms, persist=persist
)
persisted = remote.persist()

# NOTE: Feels dumb to do this. Not sure what I should be doing otherwise.
assert persisted.connection == remote.connection
assert persisted.alloc == remote.alloc
assert persisted.transforms == [*remote.transforms[:-1]]
assert persisted._persist is True

@pytest.mark.parametrize("persist", [True, False])
def test_ensure_allocation(self, mock_connection: Connection, persist: bool):
alloc = ["--time=00:01:00"]
transforms = [lambda x: f"echo Hello && {x}"]
remote = SlurmRemote(
mock_connection, alloc=alloc, transforms=transforms, persist=persist
)
persisted = remote.ensure_allocation()
raise NotImplementedError("TODO: Imporant and potentially complicated test")


def test_QueueIO(file_regression: FileRegressionFixture):
Expand Down

0 comments on commit d5b2061

Please sign in to comment.