From f6fe13ba27392be6b3c040617db816e683fa14ba Mon Sep 17 00:00:00 2001
From: Caroline
Date: Thu, 12 Dec 2024 22:24:16 -0500
Subject: [PATCH] Refactor cluster factory

---
 runhouse/resources/hardware/cluster.py        |   4 +
 .../resources/hardware/cluster_factory.py     | 542 ++++++------------
 runhouse/resources/hardware/constants.py      |  36 ++
 runhouse/resources/hardware/utils.py          | 132 ++++-
 runhouse/resources/resource.py                |  51 +-
 5 files changed, 341 insertions(+), 424 deletions(-)
 create mode 100644 runhouse/resources/hardware/constants.py

diff --git a/runhouse/resources/hardware/cluster.py b/runhouse/resources/hardware/cluster.py
index 7e057cdec..36808691a 100644
--- a/runhouse/resources/hardware/cluster.py
+++ b/runhouse/resources/hardware/cluster.py
@@ -445,6 +445,10 @@ def config(self, condensed: bool = True):
 
         return config
 
+    def _update_values(self, new_values: Dict[str, str]):
+        for key, val in new_values.items():
+            setattr(self, key, val)
+
     def endpoint(self, external: bool = False):
         """Endpoint for the cluster's Daemon server.

diff --git a/runhouse/resources/hardware/cluster_factory.py b/runhouse/resources/hardware/cluster_factory.py
index 9ffd6dee8..6dc3a24be 100644
--- a/runhouse/resources/hardware/cluster_factory.py
+++ b/runhouse/resources/hardware/cluster_factory.py
@@ -1,44 +1,45 @@
-import os
-import subprocess
-import warnings
-
 from typing import Dict, List, Optional, Union
 
-from runhouse.constants import RESERVED_SYSTEM_NAMES
-from runhouse.globals import rns_client
-
 from runhouse.logger import get_logger
-from runhouse.resources.hardware.utils import LauncherType, ServerConnectionType
+from runhouse.resources.hardware.cluster import Cluster
+from runhouse.resources.hardware.constants import (
+    KUBERNETES_CLUSTER_ARGS,
+    ONDEMAND_COMPUTE_ARGS,
+    RH_SERVER_ARGS,
+)
+from runhouse.resources.hardware.on_demand_cluster import OnDemandCluster
+from runhouse.resources.hardware.utils import (
+    _compare_config_with_alt_options,
+    LauncherType,
+    ServerConnectionType,
+    setup_kubernetes,
+)
 from runhouse.resources.images.image import Image
 
-from .cluster import Cluster
-from .on_demand_cluster import OnDemandCluster
-
 logger = get_logger(__name__)
 
-# Cluster factory method
+
 def cluster(
     name: str,
     host: Union[str, List[str]] = None,
     ssh_creds: Union[Dict, str] = None,
-    server_port: int = None,
-    server_host: str = None,
+    server_port: Optional[int] = None,
+    server_host: Optional[str] = None,
     server_connection_type: Union[ServerConnectionType, str] = None,
-    launcher: Union[LauncherType, str] = None,
-    ssl_keyfile: str = None,
-    ssl_certfile: str = None,
-    domain: str = None,
+    ssl_keyfile: Optional[str] = None,
+    ssl_certfile: Optional[str] = None,
+    domain: Optional[str] = None,
+    image: Optional[Image] = None,
     den_auth: bool = None,
-    image: Image = None,
     load_from_den: bool = True,
     dryrun: bool = False,
     **kwargs,
-) -> Union[Cluster, OnDemandCluster]:
+):
     """
     Builds an instance of :class:`Cluster`.
 
     Args:
-        name (str): Name for the cluster, to re-use later on.
+        name (str): Name for the cluster.
         host (str or List[str], optional): Hostname (e.g. domain or name in .ssh/config), IP address, or list of IP
             addresses for the cluster (the first of which is the head node). (Default: ``None``).
         ssh_creds (dict or str, optional): SSH credentials, passed as dictionary or the name of an `SSHSecret` object.
@@ -52,19 +53,16 @@ def cluster(
             API server. ``ssh`` will use start with server via an SSH tunnel. ``tls`` will start the server with HTTPS on port 443 using TLS certs without an SSH tunnel.
``none`` will start the server with HTTP without an SSH tunnel. (Default: ``None``). - launcher (LauncherType or str, optional): Method for launching the cluster. If set to `local`, will launch - locally via Sky. If set to `den`, launching will be handled by Runhouse. Currently only relevant for - ondemand clusters and kubernetes clusters. (Default: ``local``). ssl_keyfile(str, optional): Path to SSL key file to use for launching the API server with HTTPS. (Default: ``None``). ssl_certfile(str, optional): Path to SSL certificate file to use for launching the API server with HTTPS. (Default: ``None``). domain(str, optional): Domain name for the cluster. Relevant if enabling HTTPs on the cluster. (Default: ``None``). + image (Image, optional): Default image containing setup steps to run during cluster setup. See :class:`Image`. + (Default: ``None``) den_auth (bool, optional): Whether to use Den authorization on the server. If ``True``, will validate incoming requests with a Runhouse token provided in the auth headers of the request with the format: ``{"Authorization": "Bearer "}``. (Default: ``None``). - image (Image, optional): Default image containing setup steps to run during cluster setup. See :class:`Image`. - (Default: ``None``) load_from_den (bool): Whether to try loading the Cluster resource from Den. (Default: ``True``) dryrun (bool): Whether to create the Cluster if it doesn't exist, or load a Cluster object as a dryrun. (Default: ``False``) @@ -91,216 +89,75 @@ def cluster( >>> # Load cluster from above >>> reloaded_cluster = rh.cluster(name="rh-a10x") """ - if host and kwargs.get("ips"): + cluster_args = locals() + cluster_args.pop("kwargs") + cluster_args = {k: v for k, v in cluster_args.items() if v is not None} + + # check for invalid args + valid_kwargs = ONDEMAND_COMPUTE_ARGS | KUBERNETES_CLUSTER_ARGS | RH_SERVER_ARGS + unsupported_kwargs = kwargs.keys() - valid_kwargs + if unsupported_kwargs: raise ValueError( - "Cluster factory method can only accept one of `host` or `ips` as an argument." + f"Received unsupported kwargs {unsupported_kwargs}. " + "Please refer to `rh.cluster` or `rh.ondemand_cluster` for valid input args." ) - if name: - alt_options = dict( - host=host, - ssh_creds=ssh_creds, - server_port=server_port, - server_host=server_host, - server_connection_type=server_connection_type, - ssl_keyfile=ssl_keyfile, - ssl_certfile=ssl_certfile, - domain=domain, - den_auth=den_auth, - ) - alt_options.update(kwargs) - # Filter out None/default values - alt_options = {k: v for k, v in alt_options.items() if v is not None} - try: - c = Cluster.from_name( - name, - load_from_den=load_from_den, - dryrun=dryrun, - _alt_options=alt_options, + if kwargs.keys() & (ONDEMAND_COMPUTE_ARGS | KUBERNETES_CLUSTER_ARGS): + if host or ssh_creds: + raise ValueError( + "Received incompatible args specific to both a static and ondemand cluster." 
) - if c: - c.set_connection_defaults() - # If the the user changed the image and wants to restart the server to apply the new - # changes, we need to update the image in the cluster object - c.image = image or c.image - if kwargs.get("autostop_mins") and c.autostop_mins != kwargs.get( - "autostop_mins" - ): - c.autostop_mins = kwargs.get("autostop_mins") - if launcher: - c.launcher = launcher - if den_auth: - c.save() - return c - except ValueError as e: - if not alt_options: - raise e - - # Check if any of the ondemand_cluster-specific arguments are in the kwargs - if {"instance_type", "memory", "disk_size", "num_cpus", "accelerators"} & set( - kwargs.keys() - ): return ondemand_cluster( - name=name, - server_port=server_port, - server_host=server_host, - server_connection_type=server_connection_type, - launcher=launcher, - ssl_keyfile=ssl_keyfile, - ssl_certfile=ssl_certfile, - domain=domain, - den_auth=den_auth, - image=image, - dryrun=dryrun, - load_from_den=load_from_den, + **cluster_args, **kwargs, ) - if isinstance(host, str): - host = [host] - - ssh_creds = ssh_creds or rns_client.default_ssh_key - - c = Cluster( - ips=kwargs.pop("ips", None) or host, - creds=ssh_creds, - name=name, - server_host=server_host, - server_port=server_port, - server_connection_type=server_connection_type, - ssl_keyfile=ssl_keyfile, - ssl_certfile=ssl_certfile, - domain=domain, - den_auth=den_auth, - image=image, - dryrun=dryrun, - **kwargs, - ) - c.set_connection_defaults(**kwargs) - - if den_auth or rns_client.autosave_resources(): - c.save() - - return c - - -def kubernetes_cluster( - name: str, - instance_type: str = None, - namespace: str = None, - kube_config_path: str = None, - context: str = None, - server_connection_type: Union[ServerConnectionType, str] = None, - launcher: Union[LauncherType, str] = None, - **kwargs, -) -> OnDemandCluster: - - # if user passes provider via kwargs to kubernetes_cluster - provider_passed = kwargs.pop("provider", None) - - if provider_passed is not None and provider_passed != "kubernetes": - raise ValueError( - f"Runhouse K8s Cluster provider must be `kubernetes`. " - f"You passed {provider_passed}." + try: + new_cluster = Cluster.from_name( + name, load_from_den=load_from_den, dryrun=dryrun ) - - # checking server_connection_type passed over from ondemand_cluster factory method - if ( - server_connection_type is not None - and server_connection_type != ServerConnectionType.SSH - ): - raise ValueError( - f"Runhouse K8s Cluster server connection type must be set to `ssh`. " - f"You passed {server_connection_type}." - ) - - if context is not None and namespace is not None: - warnings.warn( - "You passed both a context and a namespace. Ensure your namespace matches the one in your context.", - UserWarning, - ) - - if namespace is not None and launcher == "local": - # Set the context only if launching locally - # check if user passed a user-defined namespace - cmd = f"kubectl config set-context --current --namespace={namespace}" - try: - process = subprocess.run( - cmd, - shell=True, - check=True, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - text=True, + if isinstance(new_cluster, OnDemandCluster): + # load from name, none of the other arguments were provided + cluster_type = "ondemand" + else: + cluster_type = "static" + except ValueError: + new_cluster = None + cluster_type = "unsaved" + + if cluster_type == "unsaved": + if not cluster_args: + raise ValueError( + f"Cluster {name} not found in Den. 
Must provide cluster arguments to construct "
+                "a new cluster object."
+            )
-            logger.debug(process.stdout)
-            logger.info(f"Kubernetes namespace set to {namespace}")
-
-        except subprocess.CalledProcessError as e:
-            logger.info(f"Error: {e}")
-
-    if (
-        kube_config_path is not None
-    ):  # check if user passed a user-defined kube_config_path
-        kube_config_dir = os.path.expanduser("~/.kube")
-        kube_config_path_rl = os.path.join(kube_config_dir, "config")
-
-        if not os.path.exists(
-            kube_config_dir
-        ):  # check if ~/.kube directory exists on local machine
-            try:
-                os.makedirs(
-                    kube_config_dir
-                )  # create ~/.kube directory if it doesn't exist
-                logger.info(f"Created directory: {kube_config_dir}")
-            except OSError as e:
-                logger.info(f"Error creating directory: {e}")
-
-        if os.path.exists(kube_config_path_rl):
-            raise Exception(
-                "A kubeconfig file already exists in ~/.kube directory. Aborting."
+        new_cluster = Cluster(**cluster_args)
+    elif cluster_type == "static":
+        if new_cluster.is_up():
+            mismatches = _compare_config_with_alt_options(
+                new_cluster.config(), cluster_args
             )
+            server_mismatches = mismatches.keys() & RH_SERVER_ARGS
+            if server_mismatches:
+                logger.warning(
+                    "Runhouse server setting has been updated. Please run `cluster.restart_server()` "
+                    f"to apply new server settings for {server_mismatches}"
+                )
+        new_cluster = Cluster(**cluster_args)
-
-        try:
-            cmd = f"cp {kube_config_path} {kube_config_path_rl}"  # copy user-defined kube_config to ~/.kube/config
-            subprocess.run(cmd, shell=True, check=True)
-            logger.info(f"Copied kubeconfig to: {kube_config_path}")
-        except subprocess.CalledProcessError as e:
-            logger.info(f"Error copying kubeconfig: {e}")
-
-    if context is not None and launcher == "local":
-        # check if user passed a user-defined context
-        try:
-            cmd = f"kubectl config use-context {context}"  # set user-defined context as current context
-            subprocess.run(cmd, shell=True, check=True)
-            logger.info(f"Kubernetes context has been set to: {context}")
-        except subprocess.CalledProcessError as e:
-            logger.info(f"Error setting context: {e}")
+
+    if den_auth:
+        new_cluster.save()
+    return new_cluster
-    c = OnDemandCluster(
-        name=name,
-        instance_type=instance_type,
-        provider="kubernetes",
-        launcher=launcher,
-        server_connection_type=server_connection_type,
-        namespace=namespace,
-        context=context,
-        **kwargs,
-    )
-    c.set_connection_defaults()
-    return c
-
-
-# OnDemandCluster factory method
 def ondemand_cluster(
     name: str,
+    # sky arguments
     instance_type: Optional[str] = None,
     num_nodes: Optional[int] = None,
     provider: Optional[str] = None,
     autostop_mins: Optional[int] = None,
     use_spot: bool = False,
-    image: Image = None,
     region: Optional[str] = None,
     memory: Union[int, str, None] = None,
     disk_size: Union[int, str, None] = None,
@@ -308,47 +165,58 @@
     accelerators: Union[int, str, None] = None,
     open_ports: Union[int, str, List[int], None] = None,
     sky_kwargs: Dict = None,
+    # kubernetes related arguments
+    namespace: Optional[str] = None,
+    kube_config_path: Optional[str] = None,
+    context: Optional[str] = None,
+    # runhouse server arguments
     server_port: int = None,
-    server_host: int = None,
+    server_host: str = None,
     server_connection_type: Union[ServerConnectionType, str] = None,
-    launcher: Union[LauncherType, str] = None,
     ssl_keyfile: str = None,
     ssl_certfile: str = None,
     domain: str = None,
+    image: Image = None,
+    # misc arguments
+    launcher: Union[LauncherType, str] = None,
     den_auth: bool = None,
     load_from_den: bool = True,
     dryrun: bool = 
False, - **kwargs, -) -> OnDemandCluster: +): """ - Builds an instance of :class:`OnDemandCluster`. Note that region, memory, disk_size, and open_ports - are all passed through to SkyPilot's `Resource constructor - `__. + Builds an instance of :class:`OnDemandCluster`. + + Note that the following args are used for launching, and passed through to SkyPilot's + `Resource constructor `__: + instance_type, num_nodes provider, use_spot, region, memory, disk_size, num_cpus, accelerators, + open_ports, autostop_mins, sky_kwargs Args: name (str): Name for the cluster, to re-use later on. - instance_type (int, optional): Type of cloud VM type to use for the cluster, e.g. "r5d.xlarge". Optional, as - may instead choose to specify resource requirements (e.g. memory, disk_size, num_cpus, accelerators). + instance_type (int, optional): Type of cloud VM type to use for the cluster, e.g. "r5d.xlarge". + Optional, as may instead choose to specify resource requirements (e.g. memory, disk_size, + num_cpus, accelerators). num_nodes (int, optional): Number of nodes to use for the cluster. provider (str, optional): Cloud provider to use for the cluster. autostop_mins (int, optional): Number of minutes to keep the cluster up after inactivity, or ``-1`` to keep cluster up indefinitely. (Default: ``60``). use_spot (bool, optional): Whether or not to use spot instance. region (str, optional): The region to use for the cluster. - memory (int or str, optional): Amount of memory to use for the cluster, e.g. "16" or "16+". - disk_size (int or str, optional): Amount of disk space to use for the cluster, e.g. "100" or "100+". - num_cpus (int or str, optional): Number of CPUs to use for the cluster, e.g. "4" or "4+". + memory (int or str, optional): Amount of memory to use for the cluster, e.g. `16` or "16+". + disk_size (int or str, optional): Amount of disk space to use for the cluster, e.g. `100` or "100+". + num_cpus (int or str, optional): Number of CPUs to use for the cluster, e.g. `4` or "4+". accelerators (int or str, optional): Number of accelerators to use for the cluster, e.g. "A101" or "L4:8". open_ports (int or str or List[int], optional): Ports to open in the cluster's security group. Note that you are responsible for ensuring that the applications listening on these ports are secure. - sky_kwargs (dict, optional): Additional keyword arguments to pass to the SkyPilot `Resource` or - `launch` APIs. Should be a dict of the form - `{"resources": {}, "launch": {}}`, where resources_kwargs and - launch_kwargs will be passed to the SkyPilot Resources API - (See `SkyPilot docs `__) - and `launch` API (See - `SkyPilot docs `__), respectively. - Any arguments which duplicate those passed to the `ondemand_cluster` factory method will raise an error. + sky_kwargs (dict, optional): Additional keyword arguments to pass to the SkyPilot `Resource` or `launch` + APIs. Should be a dict of the form `{"resources": {}, "launch": {}}`, + where resources_kwargs and launch_kwargs will be passed to the SkyPilot Resources API (See + `SkyPilot docs `__) and `launch` + API (See `SkyPilot docs `__), + respectively. Duplicating arguments passed to the `ondemand_cluster` factory method will raise an error. + namespace (str, optional): Namespace for kubernetes cluster, if applicable. + kube_config_path (str, optional): Path to the kube_config, for a kubernetes cluster. + context (str, optional): Context for kubernetes cluster, if applicable. server_port (bool, optional): Port to use for the server. 
If not provided will use 80 for a ``server_connection_type`` of ``none``, 443 for ``tls`` and ``32300`` for all other SSH connection types. server_host (bool, optional): Host from which the server listens for traffic (i.e. the --host argument @@ -363,11 +231,11 @@ def ondemand_cluster( ssl_keyfile(str, optional): Path to SSL key file to use for launching the API server with HTTPS. ssl_certfile(str, optional): Path to SSL certificate file to use for launching the API server with HTTPS. domain(str, optional): Domain name for the cluster. Relevant if enabling HTTPs on the cluster. + image (Image, optional): Default image containing setup steps to run during cluster setup. See :class:`Image`. + (Default: ``None``) den_auth (bool, optional): Whether to use Den authorization on the server. If ``True``, will validate incoming requests with a Runhouse token provided in the auth headers of the request with the format: ``{"Authorization": "Bearer "}``. (Default: ``None``). - image (Image, optional): Default image containing setup steps to run during cluster setup. See :class:`Image`. - (Default: ``None``) load_from_den (bool): Whether to try loading the Cluster resource from Den. (Default: ``True``) dryrun (bool): Whether to create the Cluster if it doesn't exist, or load a Cluster object as a dryrun. (Default: ``False``) @@ -389,145 +257,67 @@ def ondemand_cluster( >>> # Load cluster from above >>> reloaded_cluster = rh.ondemand_cluster(name="rh-4-a100s") """ + cluster_args = locals() + cluster_args = {k: v for k, v in cluster_args.items() if v is not None} - if name in RESERVED_SYSTEM_NAMES: - raise ValueError( - f"Cluster name {name} is a reserved name. Please use a different name which is not one of " - f"{RESERVED_SYSTEM_NAMES}." - ) - - if "num_instances" in kwargs: - logger.warning( - "The `num_instances` argument is deprecated and will be removed in a future version. " - "Please use the argument `num_nodes` instead to refer to the number of nodes for the cluster." - ) - num_nodes = kwargs.get("num_instances") - - if launcher and launcher not in LauncherType.__members__.values(): - raise ValueError( - f"Invalid launcher type {launcher}. Specify either 'den' or 'local' " - f"in the cluster factory or add a `launcher` field to your " - f"local ~/.rh/config.yaml." - ) - - if instance_type and any([memory, disk_size, num_cpus]): - raise ValueError( - "Resources are over-specified. Cannot specify both `instance_type` and any of `memory`, `disk_size`, " - "or `num_cpus`." 
+ try: + new_cluster = Cluster.from_name( + name, load_from_den=load_from_den, dryrun=dryrun ) - - if name: - alt_options = dict( - instance_type=instance_type, - num_nodes=num_nodes, - provider=provider, - region=region, - memory=memory, - disk_size=disk_size, - num_cpus=num_cpus, - accelerators=accelerators, - open_ports=open_ports, - server_host=server_host, - server_port=server_port, - server_connection_type=server_connection_type, - ssl_keyfile=ssl_keyfile, - ssl_certfile=ssl_certfile, - domain=domain, - den_auth=den_auth, - ) - # Filter out None/default values - alt_options = {k: v for k, v in alt_options.items() if v is not None} - try: - c = Cluster.from_name( - name, - load_from_den=load_from_den, - dryrun=dryrun, - _alt_options=alt_options, + cluster_type = "ondemand" + except ValueError: + new_cluster = None + cluster_type = "unsaved" + + if cluster_args.keys() & KUBERNETES_CLUSTER_ARGS: + setup_kubernetes(**cluster_args) + + if cluster_type == "unsaved": + if not cluster_args: + raise ValueError( + f"OndemandCluster {name} not found in Den. Must provide cluster arguments to construct " + "a new cluster object." ) - if c: - c.set_connection_defaults() - # If the the user changed the image and wants to restart the server to apply the new - # changes, we need to update the image in the cluster object - c.image = image or c.image - if autostop_mins and autostop_mins != c.autostop_mins: - c.autostop_mins = autostop_mins - c.launcher = launcher or c.launcher - if den_auth: - c.save() - return c - except ValueError as e: - if launcher == LauncherType.LOCAL: - import sky - - state = sky.status(cluster_names=[name], refresh=False) - if len(state) == 0 and not alt_options: - raise e - - if provider == "kubernetes": - namespace = kwargs.pop("namespace", None) - kube_config_path = kwargs.pop("kube_config_path", None) - context = kwargs.pop("context", None) - server_connection_type = kwargs.pop("server_connection_type", None) - - return kubernetes_cluster( - name=name, - instance_type=instance_type, - namespace=namespace, - kube_config_path=kube_config_path, - context=context, - server_connection_type=server_connection_type, - launcher=launcher, - autostop_mins=autostop_mins, - num_nodes=num_nodes, - provider=provider, - use_spot=use_spot, - image=image, - region=region, - memory=memory, - disk_size=disk_size, - num_cpus=num_cpus, - accelerators=accelerators, - open_ports=open_ports, - sky_kwargs=sky_kwargs, - server_port=server_port, - server_host=server_host, - ssl_keyfile=ssl_keyfile, - ssl_certfile=ssl_certfile, - domain=domain, - den_auth=den_auth, - dryrun=dryrun, - **kwargs, + new_cluster = OnDemandCluster(**cluster_args) + elif cluster_type == "ondemand": + mismatches = _compare_config_with_alt_options( + new_cluster.config(), cluster_args ) - - c = OnDemandCluster( - instance_type=instance_type, - provider=provider, - num_nodes=num_nodes, - autostop_mins=autostop_mins, - use_spot=use_spot, - image=image, - region=region, - memory=memory, - disk_size=disk_size, - num_cpus=num_cpus, - accelerators=accelerators, - open_ports=open_ports, - sky_kwargs=sky_kwargs, - server_host=server_host, - server_port=server_port, - server_connection_type=server_connection_type, - launcher=launcher, - ssl_keyfile=ssl_keyfile, - ssl_certfile=ssl_certfile, - domain=domain, - den_auth=den_auth, - name=name, - dryrun=dryrun, - **kwargs, - ) - c.set_connection_defaults() - - if den_auth or rns_client.autosave_resources(): - c.save() - - return c + compute_mismatches = { + k: v + for k, v in 
mismatches.items()
+            if k in (ONDEMAND_COMPUTE_ARGS | KUBERNETES_CLUSTER_ARGS)
+        }
+        server_mismatches = {k: v for k, v in mismatches.items() if k in RH_SERVER_ARGS}
+        new_autostop_mins = compute_mismatches.pop("autostop_mins", None)
+
+        if new_cluster.is_up():
+            # cluster is up - throw error if launch compute mismatches, but allow server / autostop min updates
+            if compute_mismatches:
+                raise ValueError(
+                    f"Cluster {name} is up, but received argument mismatches for compute: {compute_mismatches.keys()}. "
+                    "Please construct a new cluster object or ensure that the arguments match."
+                )
+            if new_autostop_mins:
+                logger.info(f"Updating autostop mins for cluster {name}")
+                new_cluster.autostop_mins = new_autostop_mins
+            if server_mismatches:
+                new_cluster._update_values(server_mismatches)
+                logger.warning(
+                    "Runhouse server setting has been updated. Please run `cluster.restart_server()` "
+                    f"to apply new server settings for {server_mismatches.keys()}"
+                )
+        else:
+            # cluster is down
+            # - construct new cluster if launch compute mismatches
+            # - if compute matches/empty but server or autostop mismatches, override just those values
+            if compute_mismatches:
+                new_cluster = OnDemandCluster(**cluster_args)
+            elif server_mismatches:
+                new_cluster._update_values(server_mismatches)
+            if new_autostop_mins:
+                new_cluster._autostop_mins = new_autostop_mins
+
+    if den_auth:
+        new_cluster.save()
+    return new_cluster

diff --git a/runhouse/resources/hardware/constants.py b/runhouse/resources/hardware/constants.py
new file mode 100644
index 000000000..88489f4c4
--- /dev/null
+++ b/runhouse/resources/hardware/constants.py
@@ -0,0 +1,36 @@
+STATIC_CLUSTER_ARGS = {
+    "host",
+    "ssh_creds",
+}
+
+ONDEMAND_COMPUTE_ARGS = {
+    "instance_type",
+    "num_nodes",
+    "provider",
+    "use_spot",
+    "region",
+    "memory",
+    "disk_size",
+    "num_cpus",
+    "accelerators",
+    "sky_kwargs",
+    "launcher",
+    "autostop_mins",
+}
+
+KUBERNETES_CLUSTER_ARGS = {
+    "namespace",
+    "kube_config_path",
+    "context",
+}
+
+RH_SERVER_ARGS = {
+    "server_port",
+    "server_host",
+    "open_ports",  # ondemand only
+    "server_connection_type",
+    "ssl_keyfile",
+    "ssl_certfile",
+    "domain",
+    "image",
+}

diff --git a/runhouse/resources/hardware/utils.py b/runhouse/resources/hardware/utils.py
index 37067666f..f7b257f32 100644
--- a/runhouse/resources/hardware/utils.py
+++ b/runhouse/resources/hardware/utils.py
@@ -20,7 +20,7 @@
     SKY_VENV,
     TIME_UNITS,
 )
-from runhouse.globals import rns_client
+from runhouse.globals import configs, rns_client
 from runhouse.logger import get_logger
 
 from runhouse.resources.hardware.sky.command_runner import (
@@ -81,6 +81,47 @@ def load_cluster_config_from_file() -> Dict:
 
     return {}
 
+def _compare_config_with_alt_options(config, alt_options, return_config=False):
+    """Compare a resource config loaded from Den with user-provided alt_options. Checks whether any of the
+    alt_options are present in the config (with awareness of sub-resources) and whether their values differ.
+    By default, returns a dict of the mismatched options (empty if everything matches).
+
+    If ``return_config`` is True, the return value is the config itself when everything matches (use the config
+    and ignore the options), or ``None`` as soon as a mismatch is found or an alt_option is missing from the
+    config entirely (override the config with the options). Note that if alt_options are provided and the config
+    is not found, no error is raised, while if alt_options are not provided and the config is not found, an
+    error is raised.
+    """
+
+    def alt_option_to_str(val):
+        if isinstance(val, dict):
+            # This can either be a sub-resource which hasn't been converted to a resource yet, or an
+            # actual user-provided dict
+            if "rns_address" in val:
+                return val["rns_address"]
+            if "name" in val:
+                # convert a user-provided name to an rns_address
+                return rns_client.resolve_rns_path(val["name"])
+            else:
+                return val
+        elif isinstance(val, list):
+            val = [str(item) if isinstance(item, int) else item for item in val]
+        elif isinstance(val, int) or isinstance(val, float):
+            val = str(val)
+        return val
+
+    mismatches = {}
+    for key, value in alt_options.items():
+        if key in config:
+            if alt_option_to_str(value) != alt_option_to_str(config[key]):
+                if return_config:
+                    return None
+                mismatches[key] = value
+        elif return_config:
+            return None
+    return config if return_config else mismatches
+
+
 def _current_cluster(key="config"):
     """Retrieve key value from the current cluster config. If key is "config", returns entire config."""
 
@@ -711,3 +752,92 @@ def __str__(self):
         if self.retry:
             s += ", retry in {0}ms".format(self.retry)
         return s
+
+
+###################################
+# KUBERNETES SETUP
+###################################
+def setup_kubernetes(
+    namespace: Optional[str] = None,
+    kube_config_path: Optional[str] = None,
+    context: Optional[str] = None,
+    **kwargs,
+):
+    if kwargs.get("provider") and kwargs.get("provider") != "kubernetes":
+        raise ValueError(
+            f"Received non-kubernetes provider {kwargs.get('provider')} with kubernetes specific "
+            "cluster arguments."
+        )
+
+    if (
+        kwargs.get("server_connection_type")
+        and kwargs.get("server_connection_type") != ServerConnectionType.SSH
+    ):
+        raise ValueError(
+            "Runhouse K8s Cluster server connection type must be set to `ssh`. "
+            f"You passed {kwargs.get('server_connection_type')}."
+        )
+
+    if context and namespace:
+        logger.warning(
+            "You passed both a context and a namespace. Ensure your namespace matches the one in your context.",
+        )
+
+    launcher = kwargs.get("launcher") or configs.launcher
+    if namespace and launcher == "local":
+        # Set the context only if launching locally
+        # check if user passed a user-defined namespace
+        cmd = f"kubectl config set-context --current --namespace={namespace}"
+        try:
+            process = subprocess.run(
+                cmd,
+                shell=True,
+                check=True,
+                stdout=subprocess.PIPE,
+                stderr=subprocess.PIPE,
+                text=True,
+            )
+            logger.debug(process.stdout)
+            logger.info(f"Kubernetes namespace set to {namespace}")
+
+        except subprocess.CalledProcessError as e:
+            logger.info(f"Error: {e}")
+
+    # Q - should the following also only be set up for local launcher?
+    if (
+        kube_config_path is not None
+    ):  # check if user passed a user-defined kube_config_path
+        kube_config_dir = os.path.expanduser("~/.kube")
+        kube_config_path_rl = os.path.join(kube_config_dir, "config")
+
+        if not os.path.exists(
+            kube_config_dir
+        ):  # check if ~/.kube directory exists on local machine
+            try:
+                os.makedirs(
+                    kube_config_dir
+                )  # create ~/.kube directory if it doesn't exist
+                logger.info(f"Created directory: {kube_config_dir}")
+            except OSError as e:
+                logger.info(f"Error creating directory: {e}")
+
+        if os.path.exists(kube_config_path_rl):
+            raise Exception(
+                "A kubeconfig file already exists in ~/.kube directory. Aborting."
+ ) + + try: + cmd = f"cp {kube_config_path} {kube_config_path_rl}" # copy user-defined kube_config to ~/.kube/config + subprocess.run(cmd, shell=True, check=True) + logger.info(f"Copied kubeconfig to: {kube_config_path}") + except subprocess.CalledProcessError as e: + logger.info(f"Error copying kubeconfig: {e}") + + if context is not None and launcher == "local": + # check if user passed a user-defined context + try: + cmd = f"kubectl config use-context {context}" # set user-defined context as current context + subprocess.run(cmd, shell=True, check=True) + logger.info(f"Kubernetes context has been set to: {context}") + except subprocess.CalledProcessError as e: + logger.info(f"Error setting context: {e}") diff --git a/runhouse/resources/resource.py b/runhouse/resources/resource.py index 7a3e00d8b..7d91c77c8 100644 --- a/runhouse/resources/resource.py +++ b/runhouse/resources/resource.py @@ -7,6 +7,7 @@ from runhouse.globals import obj_store, rns_client from runhouse.logger import get_logger +from runhouse.resources.hardware.utils import _compare_config_with_alt_options from runhouse.rns.top_level_rns_fns import ( resolve_rns_path, save, @@ -198,52 +199,6 @@ def _check_for_child_configs(cls, config: dict): """Overload by child resources to load any resources they hold internally.""" return config - @classmethod - def _compare_config_with_alt_options(cls, config, alt_options): - """Overload by child resources to compare their config with the alt_options. If the user specifies alternate - options, compare the config with the options. It's generally up to the child class to decide how to handle the - options, but default behavior is provided. The default behavior simply checks if any of the alt_options are - present in the config (with awareness of resources), and if their values differ, return None. - - If the child class returns None, it's deciding to override the config - with the options. If the child class returns a config, it's deciding to use the config and ignore the options - (or somehow incorporate them, rarely). Note that if alt_options are provided and the config is not found, - no error is raised, while if alt_options are not provided and the config is not found, an error is raised. 
- """ - - def alt_option_to_str(val): - if isinstance(val, Resource): - if not val.rns_address and val.name and config.get("name"): - # If rns_address is missing, try current resource folder - _, folder = rns_client.split_rns_name_and_path( - rns_client.resolve_rns_path(config.get("name")) - ) - return f"{folder}/{val.name}" - return val.rns_address - elif isinstance(val, dict): - # This can either be a sub-resource which hasn't been converted to a resource yet, or an - # actual user-provided dict - if "rns_address" in val: - return val["rns_address"] - if "name" in val: - # convert a user-provided name to an rns_address - return rns_client.resolve_rns_path(val["name"]) - else: - return val - elif isinstance(val, list): - val = [str(item) if isinstance(item, int) else item for item in val] - elif isinstance(val, int) or isinstance(val, float): - val = str(val) - return val - - for key, value in alt_options.items(): - if key in config: - if alt_option_to_str(value) != alt_option_to_str(config[key]): - return None - else: - return None - return config - @classmethod def from_name( cls, @@ -269,7 +224,9 @@ def from_name( config = rns_client.load_config(name=name, load_from_den=load_from_den) if _alt_options: - config = cls._compare_config_with_alt_options(config, _alt_options) + config = _compare_config_with_alt_options( + config, _alt_options, return_config=True + ) if not config: return None if not config: