From 0bd96b0ec2fb8f940c42bad537ddd2c2fdfea763 Mon Sep 17 00:00:00 2001 From: jlewitt1 Date: Sat, 2 Nov 2024 19:09:53 +0200 Subject: [PATCH] default ssh provider updates --- runhouse/resources/hardware/cluster.py | 19 ++++++++++++++++--- .../secrets/provider_secrets/ssh_secret.py | 3 ++- runhouse/resources/secrets/secret_factory.py | 12 +++++++++--- 3 files changed, 27 insertions(+), 7 deletions(-) diff --git a/runhouse/resources/hardware/cluster.py b/runhouse/resources/hardware/cluster.py index 0f2b040128..380dcab377 100644 --- a/runhouse/resources/hardware/cluster.py +++ b/runhouse/resources/hardware/cluster.py @@ -275,7 +275,8 @@ def delete_configs(self, delete_creds: bool = False): super().delete_configs() def _setup_creds(self, ssh_creds: Union[Dict, "Secret", str]): - """Setup cluster credentials from user provided ssh_creds""" + """Setup cluster credentials from user provided ssh_creds. If no creds are provided, try using + the default SSH creds saved in Den.""" from runhouse.resources.secrets import Secret from runhouse.resources.secrets.provider_secrets.sky_secret import SkySecret from runhouse.resources.secrets.provider_secrets.ssh_secret import SSHSecret @@ -283,15 +284,27 @@ def _setup_creds(self, ssh_creds: Union[Dict, "Secret", str]): if not hasattr(self, "_creds"): self._creds = None - if not ssh_creds: - return elif isinstance(ssh_creds, Secret): self._creds = ssh_creds return + elif isinstance(ssh_creds, str): self._creds = Secret.from_name(ssh_creds) return + if not ssh_creds: + from runhouse import ProviderSecret + + try: + # Use the default ssh creds if saved in Den + ssh_secret = ProviderSecret.from_name("ssh") + self._creds = ssh_secret + return + except ValueError: + pass + + return + creds = ( copy.copy(ssh_creds) if isinstance(ssh_creds, Dict) else (ssh_creds or {}) ) diff --git a/runhouse/resources/secrets/provider_secrets/ssh_secret.py b/runhouse/resources/secrets/provider_secrets/ssh_secret.py index 7a22c35342..a9c93aef77 100644 --- a/runhouse/resources/secrets/provider_secrets/ssh_secret.py +++ b/runhouse/resources/secrets/provider_secrets/ssh_secret.py @@ -56,7 +56,8 @@ def save( if name: self.name = name elif not self.name: - self.name = f"ssh-{self.key}" + # If name not provided treat as the "default" SSH secret + self.name = self._PROVIDER return super().save( save_values=save_values, headers=headers or rns_client.request_headers(), diff --git a/runhouse/resources/secrets/secret_factory.py b/runhouse/resources/secrets/secret_factory.py index a5a4adbb86..8fd6890632 100644 --- a/runhouse/resources/secrets/secret_factory.py +++ b/runhouse/resources/secrets/secret_factory.py @@ -19,6 +19,7 @@ def secret( name (str, optional): Name to assign the secret resource. values (Dict, optional): Dictionary of secret key-value pairs. load_from_den (bool): Whether to try loading the secret from Den. (Default: ``True``) + provider (str): Provider corresponding to the secret. dryrun (bool, optional): Whether to create in dryrun mode. (Default: False) Returns: Secret: The resulting Secret object. @@ -81,13 +82,18 @@ def provider_secret( return Secret.from_name(name, load_from_den=load_from_den) elif not any([values, path, env_vars]): + # try reloading by name or provider + try: + return ProviderSecret.from_name( + name or provider, load_from_den=load_from_den + ) + except ValueError: + pass if provider in Secret.builtin_providers(as_str=True): secret_class = _get_provider_class(provider) return secret_class(name=name, provider=provider, dryrun=dryrun) else: - return ProviderSecret.from_name( - name or provider, load_from_den=load_from_den - ) + raise ValueError(f"Provider {provider} not recognized.") elif sum([bool(x) for x in [values, path, env_vars]]) == 1: secret_class = _get_provider_class(provider)