diff --git a/src/levanter/infra/tpus.py b/src/levanter/infra/tpus.py index c47d79d46..c635cc88b 100644 --- a/src/levanter/infra/tpus.py +++ b/src/levanter/infra/tpus.py @@ -203,8 +203,14 @@ def add_ssh_key(ssh_key_filename): # have to make sure .ssh exists os.makedirs(os.path.expanduser("~/.ssh"), exist_ok=True) try: - key_hash = subprocess.check_output(["ssh-keygen", "-lf", ssh_key_filename]).decode("utf-8").split()[1] - existing_keys = subprocess.check_output(["ssh-add", "-l"]).decode("utf-8").split("\n") + key_hash = ( + subprocess.check_output(["ssh-keygen", "-lf", ssh_key_filename], stderr=subprocess.STDOUT) + .decode("utf-8") + .split()[1] + ) + existing_keys = ( + subprocess.check_output(["ssh-add", "-l"], stderr=subprocess.STDOUT).decode("utf-8").split("\n") + ) for key in existing_keys: if key_hash in key: return