Skip to content

Commit

Permalink
Setup VsCode settings for remote development during mila init (#71)
Browse files Browse the repository at this point in the history
* Make functions for each chunk of `mila init` fn

Signed-off-by: Fabrice Normandin <[email protected]>

* Setup VsCode settings during `mila init`

Signed-off-by: Fabrice Normandin <[email protected]>

* Add test for the VsCode settings part of mila init

Signed-off-by: Fabrice Normandin <[email protected]>

* Add another test to check overwriting of setting

Signed-off-by: Fabrice Normandin <[email protected]>

* Fix indent of generated VsCode settings file

Signed-off-by: Fabrice Normandin <[email protected]>

* Fix the indent in the regression files

Signed-off-by: Fabrice Normandin <[email protected]>

* Use indent=4 spaces instead of tab in json file

Signed-off-by: Fabrice Normandin <[email protected]>

* Fix isort issue

Signed-off-by: Fabrice Normandin <[email protected]>

* Add tests for the vscode utils

Signed-off-by: Fabrice Normandin <[email protected]>

---------

Signed-off-by: Fabrice Normandin <[email protected]>
  • Loading branch information
lebrice authored Nov 24, 2023
1 parent 68cd311 commit e64e2dd
Show file tree
Hide file tree
Showing 10 changed files with 325 additions and 32 deletions.
37 changes: 17 additions & 20 deletions milatools/cli/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,14 @@
from urllib.parse import urlencode

import questionary as qn
from invoke import UnexpectedExit
from invoke.exceptions import UnexpectedExit
from typing_extensions import TypedDict

from ..version import version as mversion
from .init_command import (
create_ssh_keypair,
setup_ssh_config,
setup_vscode_settings,
setup_windows_ssh_config_from_wsl,
)
from .local import Local
Expand Down Expand Up @@ -394,15 +395,22 @@ def init():
print("Checking ssh config")

ssh_config = setup_ssh_config()
# TODO: Move the rest of this command to functions in the init_command module,
# so they can more easily be tested.

print("# OK")

#############################
# Step 2: Passwordless auth #
#############################
# if we're running on WSL, we actually just copy the id_rsa + id_rsa.pub and the
# ~/.ssh/config to the Windows ssh directory (taking care to remove the
# ControlMaster-related entries) so that the user doesn't need to install Python on
# the Windows side.
if running_inside_WSL():
setup_windows_ssh_config_from_wsl(linux_ssh_config=ssh_config)

setup_passwordless_ssh_access()
setup_keys_on_login_node()
setup_vscode_settings()
print_welcome_message()


def setup_passwordless_ssh_access():
print("Checking passwordless authentication")

here = Local()
Expand Down Expand Up @@ -441,10 +449,8 @@ def init():
else:
exit("No passwordless login.")

#####################################
# Step 3: Set up keys on login node #
#####################################

def setup_keys_on_login_node():
print("Checking connection to compute nodes")

remote = Remote("mila")
Expand Down Expand Up @@ -482,17 +488,8 @@ def init():
else:
exit("You will not be able to SSH to a compute node")

# TODO: IF we're running on WSL, we could probably actually just copy the
# id_rsa.pub and the config to the Windows paths (taking care to remove the
# ControlMaster-related entries) so that the user doesn't need to install Python on
# the Windows side.
if running_inside_WSL():
setup_windows_ssh_config_from_wsl(linux_ssh_config=ssh_config)

###################
# Welcome message #
###################

def print_welcome_message():
print(T.bold_cyan("=" * 60))
print(T.bold_cyan("Congrats! You are now ready to start working on the cluster!"))
print(T.bold_cyan("=" * 60))
Expand Down
91 changes: 80 additions & 11 deletions milatools/cli/init_command.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

import copy
import difflib
import json
import shutil
import subprocess
import sys
Expand All @@ -13,11 +15,14 @@

from .local import Local
from .utils import SSHConfig, T, running_inside_WSL, yn

WINDOWS_UNSUPPORTED_KEYS = ["ControlMaster", "ControlPath", "ControlPersist"]
from .vscode_utils import (
get_expected_vscode_settings_json_path,
vscode_installed,
)

logger = get_logger(__name__)

WINDOWS_UNSUPPORTED_KEYS = ["ControlMaster", "ControlPath", "ControlPersist"]
HOSTS = ["mila", "mila-cpu", "*.server.mila.quebec !*login.server.mila.quebec"]
"""List of host entries that get added to the SSH configuration by `mila init`."""

Expand Down Expand Up @@ -212,6 +217,69 @@ def create_ssh_keypair(ssh_private_key_path: Path, local: Local) -> None:
local.run("ssh-keygen", "-f", str(ssh_private_key_path), "-t", "rsa", "-N=''")


def setup_vscode_settings():
print("Setting up VsCode settings for Remote development.")

# TODO: Could also change some other useful settings as needed.

# For example, we could skip a prompt if we had the qualified node name:
# remote_platform = settings_json.get("remote.SSH.remotePlatform", {})
# remote_platform.setdefault(fully_qualified_node_name, "linux")
# settings_json["remote.SSH.remotePlatform"] = remote_platform
if not vscode_installed():
# Display a message inviting the user to install VsCode:
warnings.warn(
T.orange(
"Visual Studio Code doesn't seem to be installed on your machine "
"(either that, or the `code` command is not available on the "
"command-line.)\n"
"We would recommend installing Visual Studio Code if you want to "
"easily edit code on the cluster with the `mila code` command. "
)
)
return

try:
_update_vscode_settings_json({"remote.SSH.connectTimeout": 60})
except Exception as err:
logger.warning(
f"Unable to setup VsCode settings for remote development: {err}\n"
f"Skipping and leaving the settings unchanged.",
exc_info=err,
)


def _update_vscode_settings_json(new_values: dict[str, Any]) -> None:
vscode_settings_json_path = get_expected_vscode_settings_json_path()

settings_json: dict[str, Any] = {}
if vscode_settings_json_path.exists():
logger.info(f"Reading VsCode settings from {vscode_settings_json_path}")
with open(vscode_settings_json_path) as f:
settings_json = json.load(f)

settings_before = copy.deepcopy(settings_json)
settings_json.update(
{k: v for k, v in new_values.items() if k not in settings_json}
)

if settings_json == settings_before or not ask_to_confirm_changes(
before=json.dumps(settings_before, indent=4),
after=json.dumps(settings_json, indent=4),
path=vscode_settings_json_path,
):
print(f"Didn't change the VsCode settings at {vscode_settings_json_path}")
return

if not vscode_settings_json_path.exists():
logger.info(
f"Creating a new VsCode settings file at {vscode_settings_json_path}"
)
vscode_settings_json_path.parent.mkdir(parents=True, exist_ok=True)
with open(vscode_settings_json_path, "w") as f:
json.dump(settings_json, f, indent=4)


def _setup_ssh_config_file(config_file_path: str | Path) -> Path:
# Save the original value for the prompt. (~/.ssh/config looks better on the
# command-line).
Expand Down Expand Up @@ -244,17 +312,12 @@ def _setup_ssh_config_file(config_file_path: str | Path) -> Path:
return config_file


def _confirm_changes(ssh_config: SSHConfig, previous: str) -> bool:
print(
T.bold(
f"The following modifications will be made to your SSH config file at "
f"{ssh_config.path}:\n"
)
)
def ask_to_confirm_changes(before: str, after: str, path: str | Path) -> bool:
print(T.bold(f"The following modifications will be made to {path}:\n"))
diff_lines = list(
difflib.unified_diff(
(previous + "\n").splitlines(True),
(ssh_config.cfg.config() + "\n").splitlines(True),
before.splitlines(True),
after.splitlines(True),
)
)
for line in diff_lines[2:]:
Expand All @@ -267,6 +330,12 @@ def _confirm_changes(ssh_config: SSHConfig, previous: str) -> bool:
return yn("\nIs this OK?")


def _confirm_changes(ssh_config: SSHConfig, previous: str) -> bool:
before = previous + "\n"
after = ssh_config.cfg.config() + "\n"
return ask_to_confirm_changes(before, after, ssh_config.path)


def _get_username(ssh_config: SSHConfig) -> str:
# Check for a mila entry in ssh config
# NOTE: This also supports the case where there's a 'HOST mila some_alias_for_mila'
Expand Down
2 changes: 1 addition & 1 deletion milatools/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ class SSHConfig:
"""Wrapper around sshconf with some extra niceties."""

def __init__(self, path: str | Path):
self.path = path
self.path = Path(path)
self.cfg = read_ssh_config(path)
# self.add = self.cfg.add
self.remove = self.cfg.remove
Expand Down
38 changes: 38 additions & 0 deletions milatools/cli/vscode_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import os
import shutil
import subprocess
import sys
from logging import getLogger as get_logger
from pathlib import Path

logger = get_logger(__name__)


def running_inside_WSL() -> bool:
return sys.platform == "linux" and bool(shutil.which("powershell.exe"))


def get_expected_vscode_settings_json_path() -> Path:
if sys.platform == "win32":
return Path.home() / "AppData\\Roaming\\Code\\User\\settings.json"
if sys.platform == "darwin": # MacOS
return (
Path.home()
/ "Library"
/ "Application Support"
/ "Code"
/ "User"
/ "settings.json"
)
if running_inside_WSL():
# Need to get the Windows Home directory, not the WSL one!
windows_username = subprocess.getoutput("powershell.exe '$env:UserName'")
return Path(
f"/mnt/c/Users/{windows_username}/AppData/Roaming/Code/User/settings.json"
)
# Linux:
return Path.home() / ".config/Code/User/settings.json"


def vscode_installed() -> bool:
return bool(shutil.which(os.environ.get("MILATOOLS_CODE_COMMAND", "code")))
91 changes: 91 additions & 0 deletions tests/cli/test_init_command.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import contextlib
import json
import subprocess
import textwrap
from functools import partial
Expand All @@ -19,6 +20,7 @@
create_ssh_keypair,
get_windows_home_path_in_wsl,
setup_ssh_config,
setup_vscode_settings,
setup_windows_ssh_config_from_wsl,
)
from milatools.cli.local import Local
Expand Down Expand Up @@ -621,6 +623,95 @@ def test_setup_windows_ssh_config_from_wsl(
file_regression.check(expected_text, extension=".md")


@pytest.mark.parametrize(
"initial_settings", [None, {}, {"foo": "bar"}, {"remote.SSH.connectTimeout": 123}]
)
@pytest.mark.parametrize("accept_changes", [True, False], ids=["accept", "reject"])
def test_setup_vscode_settings(
tmp_path: Path,
monkeypatch: pytest.MonkeyPatch,
input_pipe: PipeInput,
initial_settings: dict | None,
file_regression: FileRegressionFixture,
accept_changes: bool,
):
vscode_settings_json_path = tmp_path / "settings.json"
if initial_settings is not None:
with open(vscode_settings_json_path, "w") as f:
json.dump(initial_settings, f, indent=4)

monkeypatch.setattr(
init_command,
init_command.vscode_installed.__name__,
Mock(spec=init_command.vscode_installed, return_value=True),
)
monkeypatch.setattr(
init_command,
init_command.get_expected_vscode_settings_json_path.__name__,
Mock(
spec=init_command.get_expected_vscode_settings_json_path,
return_value=vscode_settings_json_path,
),
)

user_inputs = ["y" if accept_changes else "n"]
for user_input in user_inputs:
input_pipe.send_text(user_input)

setup_vscode_settings()

resulting_contents: str | None = None
resulting_settings: dict | None = None

if not accept_changes and initial_settings is None:
# Shouldn't create the file if we don't accept the changes and there's no
# initial file.
assert not vscode_settings_json_path.exists()

if vscode_settings_json_path.exists():
resulting_contents = vscode_settings_json_path.read_text()
resulting_settings = json.loads(resulting_contents)
assert isinstance(resulting_settings, dict)

if not accept_changes:
if initial_settings is None:
assert not vscode_settings_json_path.exists()
return # skip creating the regression file in that case.
assert resulting_settings == initial_settings
return

assert resulting_contents is not None
assert resulting_settings is not None

expected_text = "\n".join(
[
f"Calling `{setup_vscode_settings.__name__}()` with "
+ (
"\n".join(
[
"this initial content:",
"",
"```json",
json.dumps(initial_settings, indent=4),
"```",
]
)
if initial_settings is not None
else "no initial VsCode settings file"
),
"",
f"and these user inputs: {tuple(user_inputs)}",
"leads the following VsCode settings file:",
"",
"```json",
resulting_contents,
"```",
]
)

file_regression.check(expected_text, extension=".md")


def test_setup_windows_ssh_config_from_wsl_copies_keys(
tmp_path: Path,
linux_ssh_config: SSHConfig,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
Calling `setup_vscode_settings()` with no initial VsCode settings file

and these user inputs: ('y',)
leads the following VsCode settings file:

```json
{
"remote.SSH.connectTimeout": 60
}
```
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
Calling `setup_vscode_settings()` with this initial content:

```json
{}
```

and these user inputs: ('y',)
leads the following VsCode settings file:

```json
{
"remote.SSH.connectTimeout": 60
}
```
Loading

0 comments on commit e64e2dd

Please sign in to comment.