diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index e8b5b4a0..2280b70f 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -25,7 +25,7 @@ jobs: - uses: actions/setup-python@v4 with: python-version: '3.10' - - run: pip install pre-commit + - run: pip install "pre-commit<4.0.0" - run: pre-commit --version - run: pre-commit install - run: pre-commit run --all-files diff --git a/milatools/cli/code.py b/milatools/cli/code.py index 2c9ef573..47615d3e 100644 --- a/milatools/cli/code.py +++ b/milatools/cli/code.py @@ -13,6 +13,7 @@ MilatoolsUserError, currently_in_a_test, internet_on_compute_nodes, + running_inside_WSL, ) from milatools.utils.compute_node import ComputeNode, salloc, sbatch from milatools.utils.disk_quota import check_disk_quota @@ -193,6 +194,9 @@ async def launch_vscode_loop(code_command: str, compute_node: ComputeNode, path: f"ssh-remote+{compute_node.hostname}", path, ) + if running_inside_WSL(): + code_command_to_run = ("powershell.exe", *code_command_to_run) + await LocalV2.run_async(code_command_to_run, display=True) # TODO: BUG: This now requires two Ctrl+C's instead of one! console.print( diff --git a/tests/cli/test_code.py b/tests/cli/test_code.py new file mode 100644 index 00000000..4e09500d --- /dev/null +++ b/tests/cli/test_code.py @@ -0,0 +1,64 @@ +"""Unit tests for the `milatools.cli.code` module. + +TODO: There are quite a few tests in `tests/integration/test_code.py` that could be +moved here, since some of them aren't exactly "integration" tests. +""" + +from unittest.mock import AsyncMock, Mock + +import pytest + +import milatools.cli.code +import milatools.cli.utils +from milatools.cli.utils import running_inside_WSL +from milatools.utils.compute_node import ComputeNode +from milatools.utils.local_v2 import LocalV2 + + +@pytest.fixture +def pretend_to_be_in_WSL( + request: pytest.FixtureRequest, monkeypatch: pytest.MonkeyPatch +): + # By default, pretend to be in WSL. Indirect parametrization can be used to + # overwrite this value for a given test (as is done below). + in_wsl = getattr(request, "param", True) + + _mock_running_inside_WSL = Mock(spec=running_inside_WSL, return_value=in_wsl) + monkeypatch.setattr( + milatools.cli.utils, + running_inside_WSL.__name__, # type: ignore + _mock_running_inside_WSL, + ) + monkeypatch.setattr( + milatools.cli.code, + running_inside_WSL.__name__, # type: ignore + _mock_running_inside_WSL, + ) + return in_wsl + + +@pytest.mark.parametrize("pretend_to_be_in_WSL", [True, False], indirect=True) +@pytest.mark.asyncio +async def test_code_from_WSL( + monkeypatch: pytest.MonkeyPatch, pretend_to_be_in_WSL: bool +): + # Mock the LocalV2 class so that we can inspect the call to `LocalV2.run_async`. + mock_localv2 = Mock(spec=LocalV2) + monkeypatch.setattr(milatools.cli.code, LocalV2.__name__, mock_localv2) + + await milatools.cli.code.launch_vscode_loop( + "code", Mock(spec=ComputeNode, hostname="foo"), "/bob/path" + ) + assert isinstance(mock_localv2.run_async, AsyncMock) + mock_localv2.run_async.assert_called_once_with( + ( + *(("powershell.exe",) if pretend_to_be_in_WSL else ()), + "code", + "--new-window", + "--wait", + "--remote", + "ssh-remote+foo", + "/bob/path", + ), + display=True, + ) diff --git a/tests/integration/test_code/test_code_mila__salloc_.txt b/tests/integration/test_code/test_code_mila__salloc_.txt index 14a6e463..053d4abc 100644 --- a/tests/integration/test_code/test_code_mila__salloc_.txt +++ b/tests/integration/test_code/test_code_mila__salloc_.txt @@ -2,7 +2,7 @@ Checking disk quota on $HOME... Disk usage: X / LIMIT GiB and X / LIMIT files (mila) $ cd $SCRATCH && salloc --wckey=milatools_test --account=SLURM_ACCOUNT --nodes=1 --ntasks=1 --cpus-per-task=1 --mem=1G --time=0:05:00 --oversubscribe --job-name=mila-code salloc: -------------------------------------------------------------------------------------------------- -salloc: # Using default long partition +salloc: # Using default long-cpu partition (CPU-only) salloc: -------------------------------------------------------------------------------------------------- salloc: Granted job allocation JOB_ID Waiting for job JOB_ID to start. diff --git a/tests/integration/test_code/test_code_mila__sbatch_.txt b/tests/integration/test_code/test_code_mila__sbatch_.txt index 249496de..b980ae67 100644 --- a/tests/integration/test_code/test_code_mila__sbatch_.txt +++ b/tests/integration/test_code/test_code_mila__sbatch_.txt @@ -5,7 +5,7 @@ Disk usage: X / LIMIT GiB and X / LIMIT files JOB_ID sbatch: -------------------------------------------------------------------------------------------------- -sbatch: # Using default long partition +sbatch: # Using default long-cpu partition (CPU-only) sbatch: -------------------------------------------------------------------------------------------------- (localhost) $ echo --new-window --wait --remote ssh-remote+COMPUTE_NODE $HOME/bob diff --git a/tests/integration/test_sync_command.py b/tests/integration/test_sync_command.py index 31f19b4a..36ae27bd 100644 --- a/tests/integration/test_sync_command.py +++ b/tests/integration/test_sync_command.py @@ -19,6 +19,7 @@ _install_vscode_extensions_task_function, sync_vscode_extensions, ) +from tests.integration.conftest import SLURM_CLUSTER from ..cli.common import ( requires_ssh_to_localhost, @@ -28,6 +29,10 @@ logger = get_logger(__name__) +@pytest.mark.xfail( + SLURM_CLUSTER == "mila", + reason="`code-server` procs are killed on the login nodes of the Mila cluster.", +) @pytest.mark.slow @pytest.mark.parametrize( "source",