Skip to content

Commit

Permalink
Fix bug with mila code --cluster=<DRAC> (#115)
Browse files Browse the repository at this point in the history
* Hotfix for bug in `mila code --cluster=<DRAC>`

Signed-off-by: Fabrice Normandin <[email protected]>

* Fix running sync vscode extensions in background

Signed-off-by: Fabrice Normandin <[email protected]>

* Use ssh key from ssh config for ssh-copy-id

Signed-off-by: Fabrice Normandin <[email protected]>

* Fix dumb unit test

Signed-off-by: Fabrice Normandin <[email protected]>

* Fix broken test for `mila init`

Signed-off-by: Fabrice Normandin <[email protected]>

* Fix broken test for `make_process`

Signed-off-by: Fabrice Normandin <[email protected]>

* Always run `ssh-copy-id` (with the right key)

Signed-off-by: Fabrice Normandin <[email protected]>

* Change where the compute node setup occurs in init

Signed-off-by: Fabrice Normandin <[email protected]>

* Fix error in test for syncing vscode extensions

Signed-off-by: Fabrice Normandin <[email protected]>

* Add temporary "fix" for failing test

Signed-off-by: Fabrice Normandin <[email protected]>

* Remove flaky check from test_ensure_allocation

Signed-off-by: Fabrice Normandin <[email protected]>

* Always cd to $SCRATCH before salloc/sbatch/srun

Signed-off-by: Fabrice Normandin <[email protected]>

* Adjust unit tests following `cd $SCRATCH` change

Signed-off-by: Fabrice Normandin <[email protected]>

* Fix test and make it slightly more agnostic to imp

Signed-off-by: Fabrice Normandin <[email protected]>

* Fix failing check for the workdir in test_code

Signed-off-by: Fabrice Normandin <[email protected]>

---------

Signed-off-by: Fabrice Normandin <[email protected]>
  • Loading branch information
lebrice authored Apr 17, 2024
1 parent 5adb85c commit e3bc550
Show file tree
Hide file tree
Showing 11 changed files with 128 additions and 88 deletions.
13 changes: 10 additions & 3 deletions milatools/cli/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,7 +584,8 @@ def code(
)
elif no_internet_on_compute_nodes(cluster):
# Sync the VsCode extensions from the local machine over to the target cluster.
run_in_the_background = False # if "pytest" not in sys.modules else True
# TODO: Make this happen in the background (without overwriting the output).
run_in_the_background = False
print(
console.log(
f"[cyan]Installing VSCode extensions that are on the local machine on "
Expand Down Expand Up @@ -680,11 +681,17 @@ def code(
if persist:
print("This allocation is persistent and is still active.")
print("To reconnect to this node:")
print(T.bold(f" mila code {path} --node {node_name}"))
print(
T.bold(
f" mila code {path} "
+ (f"--cluster={cluster} " if cluster != "mila" else "")
+ f"--node {node_name}"
)
)
print("To kill this allocation:")
assert data is not None
assert "jobid" in data
print(T.bold(f" ssh mila scancel {data['jobid']}"))
print(T.bold(f" ssh {cluster} scancel {data['jobid']}"))


def connect(identifier: str, port: int | None):
Expand Down
65 changes: 46 additions & 19 deletions milatools/cli/init_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import questionary as qn
from invoke.exceptions import UnexpectedExit

from milatools.utils.remote_v2 import SSH_CONFIG_FILE

from ..utils.vscode_utils import (
get_expected_vscode_settings_json_path,
vscode_installed,
Expand Down Expand Up @@ -238,20 +240,23 @@ def setup_passwordless_ssh_access(ssh_config: SSHConfig) -> bool:

here = Local()
sshdir = Path.home() / ".ssh"
ssh_private_key_path = Path.home() / ".ssh" / "id_rsa"

# Check if there is a public key file in ~/.ssh
if not list(sshdir.glob("id*.pub")):
if yn("You have no public keys. Generate one?"):
# Run ssh-keygen with the given location and no passphrase.
ssh_private_key_path = Path.home() / ".ssh" / "id_rsa"
create_ssh_keypair(ssh_private_key_path, here)
else:
print("No public keys.")
return False

# TODO: This uses the public key set in the SSH config file, which may (or may not)
# be the random id*.pub file that was just checked for above.
success = setup_passwordless_ssh_access_to_cluster("mila")
if not success:
return False
setup_keys_on_login_node("mila")

drac_clusters_in_ssh_config: list[str] = []
hosts_in_config = ssh_config.hosts()
Expand All @@ -277,6 +282,7 @@ def setup_passwordless_ssh_access(ssh_config: SSHConfig) -> bool:
success = setup_passwordless_ssh_access_to_cluster(drac_cluster)
if not success:
return False
setup_keys_on_login_node(drac_cluster)
return True


Expand All @@ -293,28 +299,38 @@ def setup_passwordless_ssh_access_to_cluster(cluster: str) -> bool:
print(f"Checking if passwordless SSH access is setup for the {cluster} cluster.")
# TODO: Potentially use a custom key like `~/.ssh/id_milatools.pub` instead of
# the default.
ssh_private_key_path = Path.home() / ".ssh" / "id_rsa"

from paramiko.config import SSHConfig

config = SSHConfig.from_path(str(SSH_CONFIG_FILE))
identity_file = config.lookup(cluster).get("identityfile", "~/.ssh/id_rsa")
# Seems to be a list for some reason?
if isinstance(identity_file, list):
assert identity_file
identity_file = identity_file[0]
ssh_private_key_path = Path(identity_file).expanduser()
ssh_public_key_path = ssh_private_key_path.with_suffix(".pub")
assert ssh_public_key_path.exists()
# TODO: This will fail for clusters with 2FA.
if check_passwordless(cluster):
logger.info(f"Passwordless SSH access to {cluster} is already setup correctly.")
return True

if not yn(
f"Your public key does not appear be registered on the {cluster} cluster. "
"Register it?"
):
print("No passwordless login.")
return False

print("Please enter your password when prompted.")
# TODO: This will fail on Windows for clusters with 2FA.
# if check_passwordless(cluster):
# logger.info(f"Passwordless SSH access to {cluster} is already setup correctly.")
# return True
# if not yn(
# f"Your public key does not appear be registered on the {cluster} cluster. "
# "Register it?"
# ):
# print("No passwordless login.")
# return False
print("Please enter your password if prompted.")
if sys.platform == "win32":
# NOTE: This is to remove extra '^M' characters that would be added at the end
# of the file on the remote!
public_key_contents = ssh_public_key_path.read_text().replace("\r\n", "\n")
command = (
"ssh",
"-i",
str(ssh_private_key_path),
"-o",
"StrictHostKeyChecking=no",
cluster,
Expand All @@ -328,7 +344,15 @@ def setup_passwordless_ssh_access_to_cluster(cluster: str) -> bool:
f.seek(0)
subprocess.run(command, check=True, text=False, stdin=f)
else:
here.run("ssh-copy-id", "-o", "StrictHostKeyChecking=no", cluster, check=True)
here.run(
"ssh-copy-id",
"-i",
str(ssh_private_key_path),
"-o",
"StrictHostKeyChecking=no",
cluster,
check=True,
)

# double-check that this worked.
if not check_passwordless(cluster):
Expand All @@ -337,14 +361,17 @@ def setup_passwordless_ssh_access_to_cluster(cluster: str) -> bool:
return True


def setup_keys_on_login_node():
def setup_keys_on_login_node(cluster: str = "mila"):
#####################################
# Step 3: Set up keys on login node #
#####################################

print("Checking connection to compute nodes")

remote = Remote("mila")
print(
f"Checking connection to compute nodes on the {cluster} cluster. "
"This is required for `mila code` to work properly."
)
# todo: avoid re-creating the `Remote` here, since it goes through 2FA each time!
remote = Remote(cluster)
try:
pubkeys = remote.get_lines("ls -t ~/.ssh/id*.pub")
print("# OK")
Expand Down
23 changes: 11 additions & 12 deletions milatools/cli/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from typing_extensions import Self, TypedDict, deprecated

from .utils import (
DRAC_CLUSTERS,
SSHConnectionError,
T,
cluster_to_connect_kwargs,
Expand Down Expand Up @@ -344,7 +343,6 @@ def extract(
# something like a `io.StringIO` here instead, and create an object that manages
# reading from it, and pass that `io.StringIO` buffer to `self.run`.
qio: TextIO = QueueIO()

promise = self.run(
cmd,
hide=hide,
Expand Down Expand Up @@ -467,7 +465,10 @@ def __init__(
)

def srun_transform(self, cmd: str) -> str:
return shlex.join(["srun", *self.alloc, "bash", "-c", cmd])
cmd = shlex.join(["srun", *self.alloc, "bash", "-c", cmd])
# We need to cd to $SCRATCH before we can run jobs with `srun` on some clusters.
cmd = f"cd $SCRATCH && {cmd}"
return cmd

def srun_transform_persist(self, cmd: str) -> str:
tag = time.time_ns()
Expand All @@ -482,10 +483,8 @@ def srun_transform_persist(self, cmd: str) -> str:
self.puttext(text=batch, dest=batch_file)
self.simple_run(f"chmod +x {batch_file}")
cmd = shlex.join(["sbatch", *self.alloc, str(batch_file)])

# NOTE: We need to cd to $SCRATCH before we run `sbatch` on DRAC clusters.
if self.connection.host in DRAC_CLUSTERS:
cmd = f"cd $SCRATCH && {cmd}"
# We need to cd to $SCRATCH before we run `sbatch` on some SLURM clusters.
cmd = f"cd $SCRATCH && {cmd}"
return f"{cmd}; touch {output_file}; tail -n +1 -f {output_file}"

def with_transforms(
Expand Down Expand Up @@ -518,9 +517,10 @@ def ensure_allocation(
- a dict with the compute node name (without the jobid)
- a `fabric.runners.Remote` object connected to the *login* node.
"""

if self._persist:
login_node_runner, results = self.extract(
"echo @@@ $(hostname) @@@ && sleep 1000d",
"echo @@@ $SLURMD_NODENAME @@@ && sleep 1000d",
patterns={
"node_name": "@@@ ([^ ]+) @@@",
"jobid": "Submitted batch job ([0-9]+)",
Expand All @@ -535,10 +535,9 @@ def ensure_allocation(
else:
remote = Remote(hostname=self.hostname, connection=self.connection)
command = shlex.join(["salloc", *self.alloc])
# NOTE: On some DRAC clusters, it's required to first cd to $SCRATCH or
# /projects before submitting a job.
if self.connection.host in DRAC_CLUSTERS:
command = f"cd $SCRATCH && {command}"
# We need to cd to $SCRATCH before we can run `salloc` on some clusters.
command = f"cd $SCRATCH && {command}"

proc, results = remote.extract(
command,
patterns={"node_name": "salloc: Nodes ([^ ]+) are ready for job"},
Expand Down
4 changes: 3 additions & 1 deletion milatools/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,9 @@ def make_process(
) -> multiprocessing.Process:
# Tiny wrapper around the `multiprocessing.Process` init to detect if the args and
# kwargs don't match the target signature using typing instead of at runtime.
return multiprocessing.Process(target=target, daemon=True, args=args, kwargs=kwargs)
return multiprocessing.Process(
target=target, daemon=False, args=args, kwargs=kwargs
)


def currently_in_a_test() -> bool:
Expand Down
2 changes: 1 addition & 1 deletion milatools/utils/vscode_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ def find_code_server_executable(
def parse_vscode_extensions_versions(
list_extensions_output_lines: list[str],
) -> dict[str, str]:
extensions = list_extensions_output_lines
extensions = [line for line in list_extensions_output_lines if "@" in line]

def _extension_name_and_version(extension: str) -> tuple[str, str]:
# extensions should include name@version since we use --show-versions.
Expand Down
7 changes: 7 additions & 0 deletions tests/cli/test_init_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
create_ssh_keypair,
get_windows_home_path_in_wsl,
has_passphrase,
setup_keys_on_login_node,
setup_passwordless_ssh_access,
setup_passwordless_ssh_access_to_cluster,
setup_ssh_config,
Expand Down Expand Up @@ -1575,6 +1576,12 @@ def test_setup_passwordless_ssh_access(
mock_setup_passwordless_ssh_access_to_cluster,
)

monkeypatch.setattr(
milatools.cli.init_command,
setup_keys_on_login_node.__name__,
Mock(spec=setup_keys_on_login_node),
)

result = setup_passwordless_ssh_access(ssh_config)

if not public_key_exists:
Expand Down
Loading

0 comments on commit e3bc550

Please sign in to comment.