From c9c9005f44ea11c686467e13282c519aed433d0a Mon Sep 17 00:00:00 2001 From: Tatiana Al-Chueyr Date: Tue, 26 Nov 2024 13:13:48 +0000 Subject: [PATCH 1/6] e2e with working implementation and example of the Ray dynamic config --- dev/dags/ray_dynamic_config.py | 195 ++++++++ dev/dags/ray_single_operator.py | 2 +- dev/dags/ray_taskflow_example.py | 2 +- .../ray_taskflow_example_existing_cluster.py | 4 +- dev/dags/scripts/ray.yaml | 11 +- dev/dags/setup-teardown.py | 3 +- dev/docker-compose.override.yml | 21 + dev/kind-config.yaml | 21 + .../local_development_setup.rst | 209 ++++++++- ray_provider/__init__.py | 2 +- ray_provider/constants.py | 4 + .../{decorators/ray.py => decorators.py} | 102 ++-- ray_provider/decorators/__init__.py | 0 ray_provider/{hooks/ray.py => hooks.py} | 443 ++++++++++++++---- ray_provider/hooks/__init__.py | 0 .../{operators/ray.py => operators.py} | 130 ++--- ray_provider/operators/__init__.py | 0 ray_provider/{triggers/ray.py => triggers.py} | 75 +-- ray_provider/triggers/__init__.py | 0 tests/decorators/__init__.py | 0 tests/hooks/__init__.py | 0 tests/operators/__init__.py | 0 tests/{decorators => }/test_ray_decorators.py | 19 +- tests/{hooks => }/test_ray_hooks.py | 222 ++++----- tests/{operators => }/test_ray_operators.py | 22 +- tests/{triggers => }/test_ray_triggers.py | 40 +- tests/triggers/__init__.py | 0 27 files changed, 1123 insertions(+), 404 deletions(-) create mode 100644 dev/dags/ray_dynamic_config.py create mode 100644 dev/docker-compose.override.yml create mode 100644 dev/kind-config.yaml create mode 100644 ray_provider/constants.py rename ray_provider/{decorators/ray.py => decorators.py} (67%) delete mode 100644 ray_provider/decorators/__init__.py rename ray_provider/{hooks/ray.py => hooks.py} (64%) delete mode 100644 ray_provider/hooks/__init__.py rename ray_provider/{operators/ray.py => operators.py} (77%) delete mode 100644 ray_provider/operators/__init__.py rename ray_provider/{triggers/ray.py => triggers.py} (70%) delete mode 100644 ray_provider/triggers/__init__.py delete mode 100644 tests/decorators/__init__.py delete mode 100644 tests/hooks/__init__.py delete mode 100644 tests/operators/__init__.py rename tests/{decorators => }/test_ray_decorators.py (92%) rename tests/{hooks => }/test_ray_hooks.py (79%) rename tests/{operators => }/test_ray_operators.py (96%) rename tests/{triggers => }/test_ray_triggers.py (85%) delete mode 100644 tests/triggers/__init__.py diff --git a/dev/dags/ray_dynamic_config.py b/dev/dags/ray_dynamic_config.py new file mode 100644 index 0000000..0045053 --- /dev/null +++ b/dev/dags/ray_dynamic_config.py @@ -0,0 +1,195 @@ +""" +This example illustrates three DAGs. One + +The parent DAG (ray_dynamic_config_upstream_dag) uses TriggerDagRunOperator to trigger the other two: +* ray_dynamic_config_downstream_dag_1 +* ray_dynamic_config_downstream_dag_2 + +Each downstream DAG retrieves the context data (run_context) from dag_run.conf, which is passed by the parent DAG. + +The print_context tasks in the downstream DAGs output the received context to the logs. +""" +from pathlib import Path +import re + +from airflow import DAG +from airflow.decorators import dag, task +from airflow.operators.empty import EmptyOperator +from airflow.operators.python import PythonOperator +from airflow.operators.trigger_dagrun import TriggerDagRunOperator +from airflow.utils.context import Context +from airflow.utils.dates import days_ago +from jinja2 import Template +import yaml + +from ray_provider.decorators import ray + + +CONN_ID = "ray_conn" +RAY_SPEC = Path(__file__).parent / "scripts/ray.yaml" +FOLDER_PATH = Path(__file__).parent / "ray_scripts" +RAY_TASK_CONFIG = { + "conn_id": CONN_ID, + "runtime_env": {"working_dir": str(FOLDER_PATH), "pip": ["numpy"]}, + "num_cpus": 1, + "num_gpus": 0, + "memory": 0, + "poll_interval": 5, + "ray_cluster_yaml": str(RAY_SPEC), + "xcom_task_key": "dashboard", +} + + +def slugify(value): + """ + Replace invalid characters with hyphens and make lowercase. + """ + return re.sub(r'[^\w\-\.]', '-', value).lower() + + +def create_config_from_context(context, **kwargs): + default_name = "{{ dag.dag_id }}-{{ dag_run.id }}" + + raycluster_name_template = context.get("dag_run").conf.get("raycluster_name", default_name) + raycluster_name = Template(raycluster_name_template).render(context).replace("_", "-") + raycluster_name = slugify(raycluster_name) + + raycluster_k8s_yml_filename_template = context.get("dag_run").conf.get("raycluster_k8s_yml_filename", default_name + ".yml") + raycluster_k8s_yml_filename = Template(raycluster_k8s_yml_filename_template).render(context).replace("_", "-") + raycluster_k8s_yml_filename = slugify(raycluster_k8s_yml_filename) + + with open(RAY_SPEC, "r") as file: + data = yaml.safe_load(file) + data["metadata"]["name"] = raycluster_name + + NEW_RAY_K8S_SPEC = Path(__file__).parent / "scripts" / raycluster_k8s_yml_filename + with open(NEW_RAY_K8S_SPEC, "w") as file: + yaml.safe_dump(data, file, default_flow_style=False) + + config = dict(RAY_TASK_CONFIG) + config["ray_cluster_yaml"] = str(NEW_RAY_K8S_SPEC) + return config + + +def print_context(**kwargs): + # Retrieve `conf` passed from the parent DAG + print(kwargs) + cluster_name = kwargs.get("dag_run").conf.get("raycluster_name", "No ray cluster name provided") + raycluster_k8s_yml_filename = kwargs.get("dag_run").conf.get("raycluster_k8s_yml_filename", "No ray cluster YML filename provided") + print(f"Received cluster name: {cluster_name}") + print(f"Received cluster K8s YML filename: {raycluster_k8s_yml_filename}") + + +# Downstream 1 +with DAG( + dag_id="ray_dynamic_config_child_1", + start_date=days_ago(1), + schedule_interval=None, + catchup=False, +) as dag: + + print_context_task = PythonOperator( + task_id="print_context", + python_callable=print_context, + ) + print_context_task + + @task + def generate_data(): + return [1, 2, 3] + + @ray.task(config=create_config_from_context) + def process_data_with_ray(data): + import numpy as np + import ray + + @ray.remote + def cubic(x): + return x**3 + + ray.init() + data = np.array(data) + futures = [cubic.remote(x) for x in data] + results = ray.get(futures) + mean = np.mean(results) + print(f"Mean of this population is {mean}") + return mean + + data = generate_data() + process_data_with_ray(data) + + +# Downstream 2 +with DAG( + dag_id="ray_dynamic_config_child_2", + start_date=days_ago(1), + schedule_interval=None, + catchup=False, +) as dag: + + print_context_task = PythonOperator( + task_id="print_context", + python_callable=print_context, + ) + + @task + def generate_data(): + return [1, 2, 3] + + @ray.task(config=create_config_from_context) + def process_data_with_ray(data): + import numpy as np + import ray + + @ray.remote + def square(x): + return x**2 + + ray.init() + data = np.array(data) + futures = [square.remote(x) for x in data] + results = ray.get(futures) + mean = np.mean(results) + print(f"Mean of this population is {mean}") + return mean + + + data = generate_data() + process_data_with_ray(data) + + +# Upstream +with DAG( + dag_id="ray_dynamic_config_parent", + start_date=days_ago(1), + schedule_interval=None, + catchup=False, +) as dag: + empty_task = EmptyOperator(task_id="empty_task") + + trigger_dag_1 = TriggerDagRunOperator( + task_id="trigger_downstream_dag_1", + trigger_dag_id="ray_dynamic_config_child_1", + conf={ + "raycluster_name": "first-{{ dag_run.id }}", + "raycluster_k8s_yml_filename": "first-{{ dag_run.id }}.yaml" + }, + ) + + trigger_dag_2 = TriggerDagRunOperator( + task_id="trigger_downstream_dag_2", + trigger_dag_id="ray_dynamic_config_child_2", + conf={}, + ) + + # Illustrates that by default two DAG runs of the same DAG will be using different Ray clusters + # Disabled because in the local dev MacOS we're only managing to spin up two Ray Cluster services concurrently + #trigger_dag_3 = TriggerDagRunOperator( + # task_id="trigger_downstream_dag_3", + # trigger_dag_id="ray_dynamic_config_child_2", + # conf={}, + #) + + empty_task >> trigger_dag_1 + trigger_dag_1 >> trigger_dag_2 + #trigger_dag_1 >> trigger_dag_3 diff --git a/dev/dags/ray_single_operator.py b/dev/dags/ray_single_operator.py index 6057515..cbca222 100644 --- a/dev/dags/ray_single_operator.py +++ b/dev/dags/ray_single_operator.py @@ -3,7 +3,7 @@ from airflow import DAG -from ray_provider.operators.ray import SubmitRayJob +from ray_provider.operators import SubmitRayJob CONN_ID = "ray_conn" RAY_SPEC = Path(__file__).parent / "scripts/ray.yaml" diff --git a/dev/dags/ray_taskflow_example.py b/dev/dags/ray_taskflow_example.py index 5878cf0..9ccef6e 100644 --- a/dev/dags/ray_taskflow_example.py +++ b/dev/dags/ray_taskflow_example.py @@ -3,7 +3,7 @@ from airflow.decorators import dag, task -from ray_provider.decorators.ray import ray +from ray_provider.decorators import ray CONN_ID = "ray_conn" RAY_SPEC = Path(__file__).parent / "scripts/ray.yaml" diff --git a/dev/dags/ray_taskflow_example_existing_cluster.py b/dev/dags/ray_taskflow_example_existing_cluster.py index 9160f50..a5515d2 100644 --- a/dev/dags/ray_taskflow_example_existing_cluster.py +++ b/dev/dags/ray_taskflow_example_existing_cluster.py @@ -3,9 +3,9 @@ from airflow.decorators import dag, task -from ray_provider.decorators.ray import ray +from ray_provider.decorators import ray -CONN_ID = "ray_job" +CONN_ID = "ray_conn" FOLDER_PATH = Path(__file__).parent / "ray_scripts" RAY_TASK_CONFIG = { "conn_id": CONN_ID, diff --git a/dev/dags/scripts/ray.yaml b/dev/dags/scripts/ray.yaml index 0c4313f..9af50e6 100644 --- a/dev/dags/scripts/ray.yaml +++ b/dev/dags/scripts/ray.yaml @@ -1,7 +1,8 @@ apiVersion: ray.io/v1 kind: RayCluster metadata: - name: raycluster-complete + name: tati-raycluster + # namespace: tati spec: rayVersion: "2.10.0" enableInTreeAutoscaling: true @@ -15,9 +16,11 @@ spec: labels: ray-node-type: head spec: + imagePullSecrets: + - name: my-registry-secret containers: - name: ray-head - image: rayproject/ray-ml:latest + image: rayproject/ray:2.20.0-aarch64 resources: limits: cpu: 1 @@ -50,9 +53,11 @@ spec: template: metadata: spec: + imagePullSecrets: + - name: my-registry-secret containers: - name: machine-learning - image: rayproject/ray-ml:latest + image: rayproject/ray:2.20.0-aarch64 resources: limits: cpu: 1 diff --git a/dev/dags/setup-teardown.py b/dev/dags/setup-teardown.py index c2ac712..69589b8 100644 --- a/dev/dags/setup-teardown.py +++ b/dev/dags/setup-teardown.py @@ -3,7 +3,7 @@ from airflow import DAG -from ray_provider.operators.ray import DeleteRayCluster, SetupRayCluster, SubmitRayJob +from ray_provider.operators import DeleteRayCluster, SetupRayCluster, SubmitRayJob CONN_ID = "ray_conn" RAY_SPEC = Path(__file__).parent / "scripts/ray.yaml" @@ -42,3 +42,4 @@ # Create ray cluster and submit ray job setup_cluster.as_setup() >> submit_ray_job >> delete_cluster.as_teardown() setup_cluster >> delete_cluster + diff --git a/dev/docker-compose.override.yml b/dev/docker-compose.override.yml new file mode 100644 index 0000000..dea37d4 --- /dev/null +++ b/dev/docker-compose.override.yml @@ -0,0 +1,21 @@ +version: '3.8' + +services: + webserver: + #image: dev_c226a1/airflow:latest + networks: + - kind + + scheduler: + #image: dev_c226a1/airflow:latest + networks: + - kind + + triggerer: + #image: dev_c226a1/airflow:latest + networks: + - kind + +networks: + kind: + external: true diff --git a/dev/kind-config.yaml b/dev/kind-config.yaml new file mode 100644 index 0000000..f417105 --- /dev/null +++ b/dev/kind-config.yaml @@ -0,0 +1,21 @@ +kind: Cluster +name: local +apiVersion: kind.x-k8s.io/v1alpha4 +networking: + apiServerAddress: "0.0.0.0" + apiServerPort: 6443 +nodes: + - role: control-plane + #extraPortMappings: + # - containerPort: 30000 + # hostPort: 30000 + # listenAddress: "0.0.0.0" + # protocol: tcp +kubeadmConfigPatchesJSON6902: +- group: kubeadm.k8s.io + version: v1beta3 + kind: ClusterConfiguration + patch: | + - op: add + path: /apiServer/certSANs/- + value: host.docker.internal diff --git a/docs/getting_started/local_development_setup.rst b/docs/getting_started/local_development_setup.rst index 2352db1..9db263b 100644 --- a/docs/getting_started/local_development_setup.rst +++ b/docs/getting_started/local_development_setup.rst @@ -27,10 +27,48 @@ Install the following software: 1. **Create a Kind Cluster** +(a) If you plan to access the Kind Kubernetes cluster from Airflow using Astro CLI, use the following configuration file, +also available in ``dev/kind-config.yaml``, to create the Kind cluster: + +.. code-block:: + + kind: Cluster + name: local + apiVersion: kind.x-k8s.io/v1alpha4 + networking: + apiServerAddress: "0.0.0.0" + apiServerPort: 6443 + nodes: + - role: control-plane + kubeadmConfigPatchesJSON6902: + - group: kubeadm.k8s.io + version: v1beta3 + kind: ClusterConfiguration + patch: | + - op: add + path: /apiServer/certSANs/- + value: host.docker.internal + +Use the following command to create the Kind cluster: + +.. code-block:: bash + + kind create cluster --config kind-config.yaml + +If you don't do this, Astro CLI will have issues reaching the Kind Kubernetes cluster, raising exceptions similar to: + +.. code-block:: + + [2024-11-19, 14:52:39 UTC] {ray.py:606} ERROR - Standard Error: Error: Kubernetes cluster unreachable: Get "https://host.docker.internal:57034/version": tls: failed to verify certificate: x509: certificate is valid for kind-control-plane, kubernetes, kubernetes.default, kubernetes.default.svc, kubernetes.default.svc.cluster.local, localhost, not host.docker.internal + + +(b) Otherwise, if planning to access Kind from Airflow **outside of Astro CLi**, create a cluster using: + .. code-block:: bash kind create cluster --image=kindest/node:v1.26.0 + 2. **Deploy a KubeRay Operator** .. code-block:: bash @@ -82,7 +120,98 @@ Wait for the pods to reach the ``Running`` state 5. Access the Ray Dashboard -Visit http://127.0.0.1:8265 in your browser +Visit http://127.0.0.1:8265 in your browser. + + +Additional steps in MacOS +========================= + +When developing under MacOS (such as M1 instances), you may face some issues. The following steps describe how to overcome them. + +Requirements +------------ + +- `Docker Mac Net Connect `_ +- `MetalLB `_ + +1. Expose Kind Network to host +------------------------------ + +With Docker on Linux, you can send traffic directly to the LoadBalancer’s external IP if the IP space is within the Docker IP space. + +On MacOS, Docker does not expose the Docker network to the host". + +A workaround is to use docker-mac-net-connect: +https://github.com/chipmk/docker-mac-net-connect + +.. code-block:: bash + + # Install via Homebrew + $ brew install chipmk/tap/docker-mac-net-connect + + # Run the service and register it to launch at boot + $ sudo brew services start chipmk/tap/docker-mac-net-connect + +This will expose the Kind network to the host network seamlessly. + + +2. Enable the creation of LoadBalancers in Kind +----------------------------------------------- + +When attempting to run Ray in Kind from Airflow using Astro, you may face issues when attempting to spin up the Kubernetes LoadBalancer. +This will happen, particularly, if your DAGs create and tear down the Ray cluster, and are not using a pre-created cluster. + +A side-effect of this is that you will see the ``LoadBalancer`` hanging on the state ```` indefinitely. + +Example: + +.. code-block:: + $ kubectl get svc + + NAME TYPE CLUSTER-IP EXTERNAL-IP PORT(S) AGE + kubernetes ClusterIP 10.96.0.1 443/TCP 5d21h + my-raycluster-head-svc LoadBalancer 10.96.124.7 10001:31531/TCP,8265:30347/TCP,6379:31291/TCP,8080:30358/TCP,8000:32362/TCP 2m13s + +In a kind cluster, the Kubernetes control plane lacks direct integration with a cloud provider. +Since ``LoadBalancer`` services rely on cloud infrastructure to provision external IPs, they cannot natively work in kind without additional setup. + +By default: + +- The service type LoadBalancer won't provision an external IP. +- Services remain in a state until a workaround or external load balancer is introduced. + +You can use `MetalLB `_, a load balancer implementation for bare-metal Kubernetes clusters. Here's how to set it up: + +a) Install MetalLB: Apply the MetalLB manifests + +.. code-block:: bash + + kubectl apply -f https://raw.githubusercontent.com/metallb/metallb/v0.13.10/config/manifests/metallb-native.yaml + +b) Configure IP Address Pool: MetalLB requires a pool of IPs that it can assign. Create a ConfigMap with a range of available IPs: + +.. code-block:: bash + + cat <`_ that wraps the Astro CLI. It installs the necessary packages into your image to run the DAG locally. + 1. **Start Airflow Instance** .. code-block:: bash @@ -105,7 +235,77 @@ We have a `Makefile dict[str, Any]: "package-name": "astro-provider-ray", # Required "name": "Ray", # Required "description": "An integration between airflow and ray", # Required - "connection-types": [{"connection-type": "ray", "hook-class-name": "ray_provider.hooks.ray.RayHook"}], + "connection-types": [{"connection-type": "ray", "hook-class-name": "ray_provider.hooks.RayHook"}], "versions": [__version__], # Required } diff --git a/ray_provider/constants.py b/ray_provider/constants.py new file mode 100644 index 0000000..cdf8661 --- /dev/null +++ b/ray_provider/constants.py @@ -0,0 +1,4 @@ +from ray.job_submission import JobStatus + + +TERMINAL_JOB_STATUSES = {JobStatus.SUCCEEDED, JobStatus.STOPPED, JobStatus.FAILED} \ No newline at end of file diff --git a/ray_provider/decorators/ray.py b/ray_provider/decorators.py similarity index 67% rename from ray_provider/decorators/ray.py rename to ray_provider/decorators.py index 82ca474..8fd1cbb 100644 --- a/ray_provider/decorators/ray.py +++ b/ray_provider/decorators.py @@ -3,16 +3,17 @@ import inspect import os import re -import shutil +import tempfile import textwrap -from tempfile import mkdtemp +from datetime import timedelta +from pathlib import Path from typing import Any, Callable from airflow.decorators.base import DecoratedOperator, TaskDecorator, task_decorator_factory from airflow.exceptions import AirflowException from airflow.utils.context import Context -from ray_provider.operators.ray import SubmitRayJob +from ray_provider.operators import SubmitRayJob class _RayDecoratedOperator(DecoratedOperator, SubmitRayJob): @@ -28,10 +29,37 @@ class _RayDecoratedOperator(DecoratedOperator, SubmitRayJob): """ custom_operator_name = "@task.ray" + _config: None | dict[str, Any] | Callable[..., dict[str, Any]] = None template_fields: Any = (*SubmitRayJob.template_fields, "op_args", "op_kwargs") - def __init__(self, config: dict[str, Any], **kwargs: Any) -> None: + def __init__(self, config: dict[str, Any] | Callable[..., dict[str, Any]], **kwargs: Any) -> None: + self._config = config + self.kwargs = kwargs + super().__init__( + conn_id="", + entrypoint="python script.py", + runtime_env={}, + **kwargs + ) + + def _build_config(self, context: Context) -> dict: + if isinstance(self._config, Callable): + return self._build_config_from_callable(context) + return self._config + + def _build_config_from_callable(self, context: Context) -> dict[str, Any]: + config_params = inspect.signature(self._config).parameters + + config_kwargs = {k: v for k, v in self.kwargs.items() if k in config_params and k != "context"} + + if "context" in config_params: + config_kwargs["context"] = context + + # Call config with the prepared arguments + return self._config(**config_kwargs) + + def _load_config(self, config: dict) -> None: self.conn_id: str = config.get("conn_id", "") self.is_decorated_function = False if "entrypoint" in config else True self.entrypoint: str = config.get("entrypoint", "python script.py") @@ -39,47 +67,24 @@ def __init__(self, config: dict[str, Any], **kwargs: Any) -> None: self.num_cpus: int | float = config.get("num_cpus", 1) self.num_gpus: int | float = config.get("num_gpus", 0) - self.memory: int | float = config.get("memory", None) - self.ray_resources: dict[str, Any] | None = config.get("resources", None) - self.ray_cluster_yaml: str | None = config.get("ray_cluster_yaml", None) + self.memory: int | float = config.get("memory") + self.ray_resources: dict[str, Any] | None = config.get("resources") + self.ray_cluster_yaml: str | None = config.get("ray_cluster_yaml") self.update_if_exists: bool = config.get("update_if_exists", False) self.kuberay_version: str = config.get("kuberay_version", "1.0.0") - self.gpu_device_plugin_yaml: str = config.get( - "gpu_device_plugin_yaml", - "https://raw.githubusercontent.com/NVIDIA/k8s-device-plugin/v0.9.0/nvidia-device-plugin.yml", - ) + self.gpu_device_plugin_yaml: str = config.get("gpu_device_plugin_yaml") self.fetch_logs: bool = config.get("fetch_logs", True) self.wait_for_completion: bool = config.get("wait_for_completion", True) - job_timeout_seconds: int = config.get("job_timeout_seconds", 600) + job_timeout_seconds = config.get("job_timeout_seconds", 600) + self.job_timeout_seconds: int = timedelta(seconds=job_timeout_seconds) if job_timeout_seconds > 0 else None self.poll_interval: int = config.get("poll_interval", 60) - self.xcom_task_key: str | None = config.get("xcom_task_key", None) - self.config = config + self.xcom_task_key: str | None = config.get("xcom_task_key") if not isinstance(self.num_cpus, (int, float)): raise TypeError("num_cpus should be an integer or float value") if not isinstance(self.num_gpus, (int, float)): raise TypeError("num_gpus should be an integer or float value") - super().__init__( - conn_id=self.conn_id, - entrypoint=self.entrypoint, - runtime_env=self.runtime_env, - num_cpus=self.num_cpus, - num_gpus=self.num_gpus, - memory=self.memory, - resources=self.ray_resources, - ray_cluster_yaml=self.ray_cluster_yaml, - update_if_exists=self.update_if_exists, - kuberay_version=self.kuberay_version, - gpu_device_plugin_yaml=self.gpu_device_plugin_yaml, - fetch_logs=self.fetch_logs, - wait_for_completion=self.wait_for_completion, - job_timeout_seconds=job_timeout_seconds, - poll_interval=self.poll_interval, - xcom_task_key=self.xcom_task_key, - **kwargs, - ) - def execute(self, context: Context) -> Any: """ Execute the Ray task. @@ -88,15 +93,17 @@ def execute(self, context: Context) -> Any: :return: The result of the Ray job execution. :raises AirflowException: If job submission fails. """ - tmp_dir = None - try: + config = self._build_config(context) + self.log.info(f"Using the following config {config}") + self._load_config(config) + + with tempfile.TemporaryDirectory(prefix="ray_") as tmpdirname: + temp_dir = Path(tmpdirname) + if self.is_decorated_function: self.log.info( f"Entrypoint is not provided, is_decorated_function is set to {self.is_decorated_function}" ) - # Create a temporary directory that won't be immediately deleted - temp_dir = mkdtemp(prefix="ray_") - script_filename = os.path.join(temp_dir, "script.py") # Get the Python source code and extract just the function body full_source = inspect.getsource(self.python_callable) @@ -110,6 +117,7 @@ def execute(self, context: Context) -> Any: call_str = f"{self.python_callable.__name__}({args_str}, {kwargs_str})" # Write the script with function definition and call + script_filename = os.path.join(temp_dir, "script.py") with open(script_filename, "w") as file: file.write(function_body) file.write(f"\n\n# Execute the function\n{call_str}\n") @@ -122,21 +130,21 @@ def execute(self, context: Context) -> Any: result = super().execute(context) return result - except Exception as e: - self.log.error(f"Failed during execution with error: {e}") - raise AirflowException("Job submission failed") from e - finally: - if tmp_dir and os.path.exists(tmp_dir): - shutil.rmtree(tmp_dir) def _extract_function_body(self, source: str) -> str: """Extract the function, excluding only the ray.task decorator.""" lines = source.split("\n") + # TODO: This approach is extremely hacky. Review it. + # It feels a mistake to have a user-facing module named the same as the offical ray SDK + # In particular, the decorator is working in a very artificial way, where ray means two different things + # at the scope of the task definition (Astro Ray Provider decorator) and inside the decorated method (Ray SDK) # Find the line where the ray.task decorator is ray_task_line = next((i for i, line in enumerate(lines) if re.match(r"^\s*@ray\.task", line.strip())), -1) # Include everything except the ray.task decorator line body = "\n".join(lines[:ray_task_line] + lines[ray_task_line + 1 :]) + self.log.info("Ray job that is going to be executed: \m %s", body) + # Dedent the body return textwrap.dedent(body) @@ -146,6 +154,7 @@ class ray: def task( python_callable: Callable[..., Any] | None = None, multiple_outputs: bool | None = None, + config: dict[str, Any] | Callable[[], dict[str, Any]] | None = None, **kwargs: Any, ) -> TaskDecorator: """ @@ -153,12 +162,15 @@ def task( :param python_callable: The callable function to decorate. :param multiple_outputs: If True, will return multiple outputs. + :param config: A dictionary of configuration or a callable that returns a dictionary. :param kwargs: Additional keyword arguments. :return: The decorated task. """ + config = config or {} return task_decorator_factory( python_callable=python_callable, multiple_outputs=multiple_outputs, decorated_operator_class=_RayDecoratedOperator, + config=config, **kwargs, ) diff --git a/ray_provider/decorators/__init__.py b/ray_provider/decorators/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/ray_provider/hooks/ray.py b/ray_provider/hooks.py similarity index 64% rename from ray_provider/hooks/ray.py rename to ray_provider/hooks.py index 2000ecf..196358d 100644 --- a/ray_provider/hooks/ray.py +++ b/ray_provider/hooks.py @@ -15,6 +15,8 @@ from kubernetes import client, config from ray.job_submission import JobStatus, JobSubmissionClient +from ray_provider.constants import TERMINAL_JOB_STATUSES + class RayHook(KubernetesHook): # type: ignore """ @@ -26,13 +28,82 @@ class RayHook(KubernetesHook): # type: ignore :param conn_id: The connection ID to use when fetching connection info. """ - conn_name_attr = "ray_conn_id" + conn_name_attr = "conn_id" default_conn_name = "ray_default" conn_type = "ray" hook_name = "Ray" DEFAULT_NAMESPACE = "default" + def __init__( + self, + conn_id: str = default_conn_name, + ) -> None: + """ + Initialize the RayHook. + + :param conn_id: The connection ID to use when fetching connection info. + """ + super().__init__(conn_id=conn_id) + self.conn_id = conn_id + + self.address = self._get_field("address") or os.getenv("RAY_ADDRESS") + self.log.debug(f"Ray cluster address is: {self.address}") + self.create_cluster_if_needed = False + self.cookies = self._get_field("cookies") + self.metadata = self._get_field("metadata") + self.headers = self._get_field("headers") + self.verify = self._get_field("verify") or False + self.ray_client_instance = None + + self.default_namespace = self.get_namespace() or self.DEFAULT_NAMESPACE + self.kubeconfig: str | None = None + self.in_cluster: bool | None = None + self.client_configuration = None + self.config_file = None + self.disable_verify_ssl = None + self.disable_tcp_keepalive = None + self._is_in_cluster: bool | None = None + + self.cluster_context = self._get_field("cluster_context") + self.kubeconfig_path = self._get_field("kube_config_path") + self.kubeconfig_content = self._get_field("kube_config") + self.ray_cluster_yaml = None + + self._setup_kubeconfig(self.kubeconfig_path, self.kubeconfig_content, self.cluster_context) + + @property # TODO: cached property + def namespace(self): + if self.ray_cluster_yaml is None: + return self.default_namespace + cluster_spec = self.load_yaml_content(self.ray_cluster_yaml) + return cluster_spec["metadata"].get("namespace") or self.default_namespace + + def test_connection(self): + job_client = self.ray_client(self.address) + + job_id = job_client.submit_job( + entrypoint="import ray; ray.init(); print(ray.cluster_resources())" + ) + self.log.info(f"Ray test connection: Submitted job with ID: {job_id}") + + job_completed = False + connection_attempt = 10 + while not job_completed and connection_attempt: + time.sleep(0.5) + job_status = job_client.get_job_status(job_id) + self.log.info(f"Ray test connection: Job {job_id} status {job_status}") + if job_status in TERMINAL_JOB_STATUSES: + job_completed = True + connection_attempt -= 1 + + + if job_status != JobStatus.SUCCEEDED: + return False, f"Ray test connection failed: Job {job_id} status {job_status}" + + return True, job_status + # TODO: check webserver logs + @classmethod def get_ui_field_behaviour(cls) -> dict[str, Any]: """ @@ -71,41 +142,6 @@ def get_connection_form_widgets(cls) -> dict[str, Any]: "disable_tcp_keepalive": BooleanField(lazy_gettext("Disable TCP keepalive")), } - def __init__( - self, - conn_id: str = default_conn_name, - ) -> None: - """ - Initialize the RayHook. - - :param conn_id: The connection ID to use when fetching connection info. - """ - super().__init__(conn_id=conn_id) - self.conn_id = conn_id - - self.address = self._get_field("address") or os.getenv("RAY_ADDRESS") - self.log.info(f"Ray cluster address is: {self.address}") - self.create_cluster_if_needed = False - self.cookies = self._get_field("cookies") - self.metadata = self._get_field("metadata") - self.headers = self._get_field("headers") - self.verify = self._get_field("verify") or False - self.ray_client_instance = None - - self.namespace = self.get_namespace() or self.DEFAULT_NAMESPACE - self.kubeconfig: str | None = None - self.in_cluster: bool | None = None - self.client_configuration = None - self.config_file = None - self.disable_verify_ssl = None - self.disable_tcp_keepalive = None - self._is_in_cluster: bool | None = None - - self.cluster_context = self._get_field("cluster_context") - self.kubeconfig_path = self._get_field("kube_config_path") - self.kubeconfig_content = self._get_field("kube_config") - - self._setup_kubeconfig(self.kubeconfig_path, self.kubeconfig_content, self.cluster_context) def _setup_kubeconfig( self, kubeconfig_path: str | None, kubeconfig_content: str | None, cluster_context: str | None @@ -255,6 +291,7 @@ def _is_port_open(self, host: str, port: int) -> bool: :param port: The port number to check. :return: True if the port is open, False otherwise. """ + self.log.info(f"_is_port_open: {host} {port}") with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.settimeout(1) try: @@ -304,6 +341,7 @@ def _check_load_balancer_readiness(self, lb_details: dict[str, Any]) -> str | No ip: str | None = lb_details["ip"] hostname: str | None = lb_details["hostname"] + self.log.info(f"ports: {lb_details['ports']}") for port_info in lb_details["ports"]: port = port_info["port"] if ip and self._is_port_open(ip, port): @@ -313,6 +351,123 @@ def _check_load_balancer_readiness(self, lb_details: dict[str, Any]) -> str | No return None + def _get_node_ip(self) -> str: + """ + Retrieve the IP address of a Kubernetes node. + + :return: The IP address of a node in the Kubernetes cluster. + """ + # Example: Retrieve the first node's IP (adjust based on your cluster setup) + nodes = self.core_v1_client.list_node().items + self.log.info(f"Nodes: {nodes}") + for node in nodes: + self.log.info(f"Node address: {node.status.addresses}") + for address in node.status.addresses: + if address.type == "ExternalIP": + return address.address + + for node in nodes: + self.log.info(f"Node address: {node.status.addresses}") + for address in node.status.addresses: + if address.type == "InternalIP": + return address.address + + raise AirflowException("No valid node IP found in the cluster.") + + def _setup_node_port(self, name: str, namespace: str, context: dict) -> None: + """ + Set up the NodePort service and push URLs to XCom. + + :param name: The name of the Ray cluster. + :param namespace: The namespace where the cluster is deployed. + :param context: The Airflow task context. + """ + node_port_details: dict[str, Any] = self._wait_for_node_port_service( + service_name=f"{name}-head-svc", namespace=namespace + ) + + if node_port_details: + self.log.info(node_port_details) + + node_ports = node_port_details["node_ports"] + # Example: Assuming `node_ip` is provided as an environment variable or a known cluster node. + node_ip = self._get_node_ip() # Implement this method to return a valid node IP or DNS. + + for port in node_ports: + url = f"http://{node_ip}:{port['port']}" + context["task_instance"].xcom_push(key=port["name"], value=url) + self.log.info(f"Pushed URL to XCom: {url}") + else: + self.log.info("No NodePort URLs to push to XCom.") + + def _wait_for_node_port_service( + self, + service_name: str, + namespace: str = "default", + max_retries: int = 30, + retry_interval: int = 10, + ) -> dict[str, Any]: + """ + Wait for the NodePort service to be ready and return its details. + + :param service_name: The name of the NodePort service. + :param namespace: The namespace of the service. + :param max_retries: Maximum number of retries. + :param retry_interval: Interval between retries in seconds. + :return: A dictionary containing NodePort service details. + :raises AirflowException: If the service does not become ready within the specified retries. + """ + for attempt in range(1, max_retries + 1): + self.log.info(f"Attempt {attempt}: Checking NodePort service status...") + + try: + service: client.V1Service = self._get_service(service_name, namespace) + service_details: dict[str, Any] | None = self._get_node_port_details(service) + + if service_details: + self.log.info("NodePort service is ready.") + return service_details + + self.log.info("NodePort details not available yet. Retrying...") + except AirflowException: + self.log.info("Service is not available yet.") + + time.sleep(retry_interval) + + raise AirflowException(f"Service did not become ready after {max_retries} attempts") + + def _get_node_port_details(self, service: client.V1Service) -> dict[str, Any] | None: + """ + Extract NodePort details from the service. + + :param service: The Kubernetes service object. + :return: A dictionary containing NodePort details if available, None otherwise. + """ + node_ports = [] + for port in service.spec.ports: + if port.node_port: + node_ports.append({"name": port.name, "port": port.node_port}) + + if node_ports: + return {"node_ports": node_ports} + + return None + + def _check_node_port_connectivity(self, node_ports: list[dict[str, Any]]) -> bool: + """ + Check if the NodePort is reachable. + + :param node_ports: List of NodePort details. + :return: True if at least one NodePort is accessible, False otherwise. + """ + for port_info in node_ports: + # Replace with actual logic to test connectivity if needed. + self.log.info(f"Checking connectivity for NodePort {port_info['port']}") + # Example: Simulate readiness check. + if self._is_port_open("example-node-ip", port_info["port"]): + return True + return False + def _wait_for_load_balancer( self, service_name: str, @@ -331,11 +486,13 @@ def _wait_for_load_balancer( :raises AirflowException: If the LoadBalancer does not become ready within the specified retries. """ for attempt in range(1, max_retries + 1): - self.log.info(f"Attempt {attempt}: Checking LoadBalancer status...") + self.log.info(f"Attempt {attempt}: Checking LoadBalancer status {service_name} in {namespace}...") try: service: client.V1Service = self._get_service(service_name, namespace) + self.log.info(f"service: {service}") lb_details: dict[str, Any] | None = self._get_load_balancer_details(service) + self.log.info(f"lb_details: {lb_details}") if not lb_details: self.log.info("LoadBalancer details not available yet.") @@ -358,6 +515,41 @@ def _wait_for_load_balancer( raise AirflowException(f"LoadBalancer did not become ready after {max_retries} attempts") + def _get_load_balancer_details(self, service: client.V1Service) -> dict[str, Any] | None: + """ + Extract LoadBalancer details from the service. + + :param service: The Kubernetes service object. + :return: A dictionary containing LoadBalancer details if available, None otherwise. + """ + if service.status.load_balancer.ingress: + ingress: client.V1LoadBalancerIngress = service.status.load_balancer.ingress[0] + ip: str | None = ingress.ip + hostname: str | None = ingress.hostname + if ip or hostname: + ports: list[dict[str, Any]] = [{"name": port.name, "port": port.port} for port in service.spec.ports] + return {"ip": ip, "hostname": hostname, "ports": ports} + return None + + def _check_load_balancer_readiness(self, lb_details: dict[str, Any]) -> str | None: + """ + Check if the LoadBalancer is ready by testing port connectivity. + + :param lb_details: Dictionary containing LoadBalancer details. + :return: The working address (IP or hostname) if ready, None otherwise. + """ + ip: str | None = lb_details["ip"] + hostname: str | None = lb_details["hostname"] + + for port_info in lb_details["ports"]: + port = port_info["port"] + if ip and self._is_port_open(ip, port): + return ip + if hostname and self._is_port_open(hostname, port): + return hostname + + return None + def _validate_yaml_file(self, yaml_file: str) -> None: """ Validate the existence and format of the YAML file. @@ -398,21 +590,48 @@ def _create_or_update_cluster( :param cluster_spec: The specification of the Ray cluster. :raises AirflowException: If there's an error accessing or creating the Ray cluster. """ - try: - self.get_custom_object(group=group, version=version, plural=plural, name=name, namespace=namespace) - if update_if_exists: - self.log.info(f"Updating existing Ray cluster: {name}") - self.custom_object_client.patch_namespaced_custom_object( - group=group, version=version, namespace=namespace, plural=plural, name=name, body=cluster_spec - ) + """self.get_custom_object(group=group, version=version, plural=plural, name=name, namespace=namespace) + if update_if_exists: + self.log.info(f"Updating existing Ray cluster: {name}") + self.custom_object_client.patch_namespaced_custom_object( + group=group, version=version, namespace=namespace, plural=plural, name=name, body=cluster_spec + ) + except client.exceptions.ApiException as e: if e.status == 404: - self.log.info(f"Creating new Ray cluster: {name}") - self.create_custom_object( - group=group, version=version, namespace=namespace, plural=plural, body=cluster_spec - ) + """ + + self.log.info(f"Creating new Ray cluster: {name}") + + response = self.create_custom_object( + group=group, version=version, namespace=namespace, plural=plural, body=cluster_spec + ) + self.log.info(f"Resource created. Response: {response}") + + start_time = time.time() + wait_timeout = 300 + poll_interval = 5 + + while time.time() - start_time < wait_timeout: + try: + resource = self.get_custom_object(group=group, version=version, plural=plural, name=name, namespace=namespace) + except client.exceptions.ApiException as e: + self.log.warning(f"Error fetching resource status: {e}") + else: + status = resource.get("status", {}) + self.log.info(f"Current status: {status}") + if status.get("state") == "ready": + self.log.info(f"Resource {name} of group {group} is now ready.") + return status + + time.sleep(poll_interval) + + raise TimeoutError(f"Resource {name} of group {group} did not reach the desired state within {wait_timeout} seconds.") + + """ else: raise AirflowException(f"Error accessing Ray cluster '{name}': {e}") + """ def _setup_gpu_driver(self, gpu_device_plugin_yaml: str) -> None: """ @@ -420,12 +639,14 @@ def _setup_gpu_driver(self, gpu_device_plugin_yaml: str) -> None: :param gpu_device_plugin_yaml: Path or URL to the GPU device plugin YAML. """ - gpu_driver = self.load_yaml_content(gpu_device_plugin_yaml) - gpu_driver_name = gpu_driver["metadata"]["name"] + self.log.info("Trying to setup gpu_device_plugin_yaml %s", gpu_device_plugin_yaml) + if gpu_device_plugin_yaml: + gpu_driver = self.load_yaml_content(gpu_device_plugin_yaml) + gpu_driver_name = gpu_driver["metadata"]["name"] - if not self.get_daemon_set(gpu_driver_name): - self.log.info("Creating DaemonSet for NVIDIA device plugin...") - self.create_daemon_set(gpu_driver_name, gpu_driver) + if not self.get_daemon_set(gpu_driver_name): + self.log.info("Creating DaemonSet for NVIDIA device plugin...") + self.create_daemon_set(gpu_driver_name, gpu_driver) def _setup_load_balancer(self, name: str, namespace: str, context: Context) -> None: """ @@ -464,24 +685,26 @@ def setup_ray_cluster( :param update_if_exists: Whether to update the cluster if it already exists. :raises AirflowException: If there's an error setting up the Ray cluster. """ - try: - self._validate_yaml_file(ray_cluster_yaml) + #try: + self._validate_yaml_file(ray_cluster_yaml) + self.ray_cluster_yaml = ray_cluster_yaml - self.log.info("::group::Add KubeRay operator") - self.install_kuberay_operator(version=kuberay_version) - self.log.info("::endgroup::") + self.log.info("::group:: (Setup 1/3) Add KubeRay operator") + self.install_kuberay_operator(version=kuberay_version) + self.log.info("::endgroup::") - self.log.info("::group::Create Ray Cluster") - self.log.info("Loading yaml content for Ray cluster CRD...") - cluster_spec = self.load_yaml_content(ray_cluster_yaml) + self.log.info("::group:: (Setup 2/3) Create Ray Cluster") + self.log.info("Loading yaml content for Ray cluster CRD...") + cluster_spec = self.load_yaml_content(ray_cluster_yaml) - kind = cluster_spec["kind"] - plural = f"{kind.lower()}s" if kind == "RayCluster" else kind - name = cluster_spec["metadata"]["name"] - namespace = self.namespace - api_version = cluster_spec["apiVersion"] - group, version = api_version.split("/") if "/" in api_version else ("", api_version) + kind = cluster_spec["kind"] + plural = f"{kind.lower()}s" if kind == "RayCluster" else kind + name = cluster_spec["metadata"]["name"] + namespace = cluster_spec["metadata"].get("namespace") or self.namespace + api_version = cluster_spec["apiVersion"] + group, version = api_version.split("/") if "/" in api_version else ("", api_version) + try: self._create_or_update_cluster( update_if_exists=update_if_exists, group=group, @@ -491,17 +714,25 @@ def setup_ray_cluster( namespace=namespace, cluster_spec=cluster_spec, ) - self.log.info("::endgroup::") + except TimeoutError as e: + self._delete_ray_cluster_crd(ray_cluster_yaml) + raise AirflowException(e) + self.log.info("::endgroup::") + + #self._setup_gpu_driver(gpu_device_plugin_yaml=gpu_device_plugin_yaml) + + #self.log.info("::group:: (Step 3/3) Setup Node Port service") + #self._setup_node_port(name, namespace, context) + #self.log.info("::endgroup::") - self._setup_gpu_driver(gpu_device_plugin_yaml=gpu_device_plugin_yaml) + self.log.info("::group:: (Setup 3/3) Setup Load Balancer service") + self._setup_load_balancer(name, namespace, context) + self.log.info("::endgroup::") - self.log.info("::group::Setup Load Balancer service") - self._setup_load_balancer(name, namespace, context) - self.log.info("::endgroup::") + #except Exception as e: + # self.log.error(f"Error setting up Ray cluster: {e}") + # raise AirflowException(f"Failed to set up Ray cluster: {e}") - except Exception as e: - self.log.error(f"Error setting up Ray cluster: {e}") - raise AirflowException(f"Failed to set up Ray cluster: {e}") def _delete_ray_cluster_crd(self, ray_cluster_yaml: str) -> None: """ @@ -510,6 +741,7 @@ def _delete_ray_cluster_crd(self, ray_cluster_yaml: str) -> None: :param ray_cluster_yaml: Path to the YAML file defining the Ray cluster. :raises AirflowException: If there's an error deleting the Ray cluster. """ + self.log.info("Attempting to delete a ray cluster...") self.log.info("Loading yaml content for Ray cluster CRD...") cluster_spec = self.load_yaml_content(ray_cluster_yaml) @@ -521,14 +753,16 @@ def _delete_ray_cluster_crd(self, ray_cluster_yaml: str) -> None: group, version = api_version.split("/") if "/" in api_version else ("", api_version) try: - if self.get_custom_object(group=group, version=version, plural=plural, name=name, namespace=namespace): - self.delete_custom_object(group=group, version=version, name=name, namespace=namespace, plural=plural) - self.log.info(f"Deleted Ray cluster: {name}") - else: - self.log.info(f"Ray cluster: {name} not found. Skipping the delete step.") + self.get_custom_object(group=group, version=version, plural=plural, name=name, namespace=namespace) except client.exceptions.ApiException as e: - if e.status != 404: - raise AirflowException(f"Error deleting Ray cluster '{name}': {e}") + if e.status == 404: + self.log.info(f"Ray cluster: {name} not found. Skipping the delete step.") + else: + self.log.exception(f"Issue retrieving Ray cluster: {name}. Unable to delete it.") + else: + self.delete_custom_object(group=group, version=version, name=name, namespace=namespace, plural=plural) + self.log.info(f"Deleted Ray cluster: {name}") + def delete_ray_cluster(self, ray_cluster_yaml: str, gpu_device_plugin_yaml: str) -> None: """ @@ -538,10 +772,11 @@ def delete_ray_cluster(self, ray_cluster_yaml: str, gpu_device_plugin_yaml: str) :param gpu_device_plugin_yaml: Path or URL to the GPU device plugin YAML. Defaults to NVIDIA's plugin :raises AirflowException: If there's an error deleting the Ray cluster. """ - try: - self._validate_yaml_file(ray_cluster_yaml) + #try: + self._validate_yaml_file(ray_cluster_yaml) - """Delete the NVIDIA GPU device plugin DaemonSet if it exists.""" + if gpu_device_plugin_yaml: + #Delete the NVIDIA GPU device plugin DaemonSet if it exists. gpu_driver = self.load_yaml_content(gpu_device_plugin_yaml) gpu_driver_name = gpu_driver["metadata"]["name"] @@ -549,13 +784,18 @@ def delete_ray_cluster(self, ray_cluster_yaml: str, gpu_device_plugin_yaml: str) self.log.info("Deleting DaemonSet for NVIDIA device plugin...") self.delete_daemon_set(gpu_driver_name) - self.log.info("::group:: Delete Ray Cluster") - self._delete_ray_cluster_crd(ray_cluster_yaml=ray_cluster_yaml) - self.log.info("::endgroup::") - self.uninstall_kuberay_operator() - except Exception as e: - self.log.error(f"Error deleting Ray cluster: {e}") - raise AirflowException(f"Failed to delete Ray cluster: {e}") + self.log.info("::group:: Delete Ray Cluster") + self._delete_ray_cluster_crd(ray_cluster_yaml=ray_cluster_yaml) + self.log.info("::endgroup::") + + # TODO: review this previous behaviour of the code + # It can be problematic for us to uninstall the Kuberay operator that might have been previously installed: + self.log.info("::group:: Delete Kuberay operator") + self.uninstall_kuberay_operator() + self.log.info("::endgroup::") + #except Exception as e: + # self.log.error(f"Error deleting Ray cluster: {e}") + # raise AirflowException(f"Failed to delete Ray cluster: {e}") def _run_bash_command(self, command: str, env: dict[str, str] | None = None) -> tuple[str | None, str | None]: """ @@ -571,14 +811,17 @@ def _run_bash_command(self, command: str, env: dict[str, str] | None = None) -> try: result = subprocess.run(command, shell=True, check=True, text=True, capture_output=True, env=custom_env) - self.log.info("Standard Output: %s", result.stdout) - self.log.info("Standard Error: %s", result.stderr) + + if result.stderr: + self.log.info("Standard Error: %s", result.stderr) + else: + self.log.info("Standard Output: %s", result.stdout) return result.stdout, result.stderr except subprocess.CalledProcessError as e: - self.log.error("An error occurred while executing the command: %s", e) - self.log.error("Return code: %s", e.returncode) - self.log.error("Standard Output: %s", e.stdout) - self.log.error("Standard Error: %s", e.stderr) + self.log.warning("An error occurred while executing the command: %s", e) + self.log.warning("Return code: %s", e.returncode) + self.log.warning("Standard Output: %s", e.stdout) + self.log.warning("Standard Error: %s", e.stderr) return None, None def install_kuberay_operator( @@ -597,6 +840,7 @@ def install_kuberay_operator( helm upgrade --install kuberay-operator kuberay/kuberay-operator \ --version {version} --create-namespace --namespace {self.namespace} --kubeconfig {self.kubeconfig} """ + self.log.info(helm_command) result = self._run_bash_command(helm_command, env) self.log.info(result) return result @@ -621,6 +865,7 @@ def get_daemon_set(self, name: str) -> client.V1DaemonSet | None: :param name: The name of the DaemonSet. :return: The DaemonSet resource if found, None otherwise. """ + self.log.warning(f"Trying to find DaemonSet not found: {name}") try: api_response = self.apps_v1_client.read_namespaced_daemon_set(name, self.namespace) self.log.info(f"DaemonSet {api_response.metadata.name} retrieved.") @@ -641,6 +886,7 @@ def create_daemon_set(self, name: str, body: dict[str, Any]) -> client.V1DaemonS :param body: The body of the DaemonSet for the create action. :return: The created DaemonSet resource if successful, None otherwise. """ + self.log.warning("Trying to create create_daemon_set %s", name) if not body: self.log.error("Body must be provided for create action.") return None @@ -660,6 +906,7 @@ def delete_daemon_set(self, name: str) -> client.V1Status | None: :param name: The name of the DaemonSet. :return: The status of the delete operation if successful, None otherwise. """ + self.log.info("Trying to delete_daemon_set %s", name) try: delete_response = self.apps_v1_client.delete_namespaced_daemon_set(name=name, namespace=self.namespace) self.log.info(f"DaemonSet {name} deleted.") diff --git a/ray_provider/hooks/__init__.py b/ray_provider/hooks/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/ray_provider/operators/ray.py b/ray_provider/operators.py similarity index 77% rename from ray_provider/operators/ray.py rename to ray_provider/operators.py index 02b6a73..2b41c53 100644 --- a/ray_provider/operators/ray.py +++ b/ray_provider/operators.py @@ -10,8 +10,8 @@ from airflow.utils.context import Context from ray.job_submission import JobStatus -from ray_provider.hooks.ray import RayHook -from ray_provider.triggers.ray import RayJobTrigger +from ray_provider.hooks import RayHook +from ray_provider.triggers import RayJobTrigger class SetupRayCluster(BaseOperator): @@ -30,7 +30,8 @@ def __init__( conn_id: str, ray_cluster_yaml: str, kuberay_version: str = "1.0.0", - gpu_device_plugin_yaml: str = "https://raw.githubusercontent.com/NVIDIA/k8s-device-plugin/v0.9.0/nvidia-device-plugin.yml", + #gpu_device_plugin_yaml: str = "https://raw.githubusercontent.com/NVIDIA/k8s-device-plugin/v0.9.0/nvidia-device-plugin.yml", + gpu_device_plugin_yaml: str = "", update_if_exists: bool = False, **kwargs: Any, ) -> None: @@ -52,6 +53,7 @@ def execute(self, context: Context) -> None: :param context: The context in which the operator is being executed. """ + self.log.info("Trying to setup ray cluster") self.hook.setup_ray_cluster( context=context, ray_cluster_yaml=self.ray_cluster_yaml, @@ -59,6 +61,7 @@ def execute(self, context: Context) -> None: gpu_device_plugin_yaml=self.gpu_device_plugin_yaml, update_if_exists=self.update_if_exists, ) + self.log.info("Finished setting up the ray cluster") class DeleteRayCluster(BaseOperator): @@ -261,55 +264,48 @@ def execute(self, context: Context) -> str: :raises AirflowException: If the job fails, is cancelled, or reaches an unexpected state. """ - try: - self._setup_cluster(context=context) - - self.dashboard_url = self._get_dashboard_url(context) - - self.job_id = self.hook.submit_ray_job( - dashboard_url=self.dashboard_url, - entrypoint=self.entrypoint, - runtime_env=self.runtime_env, - entrypoint_num_cpus=self.num_cpus, - entrypoint_num_gpus=self.num_gpus, - entrypoint_memory=self.memory, - entrypoint_resources=self.ray_resources, - ) - self.log.info(f"Ray job submitted with id: {self.job_id}") - - if self.wait_for_completion: - current_status = self.hook.get_ray_job_status(self.dashboard_url, self.job_id) - self.log.info(f"Current job status for {self.job_id} is: {current_status}") - - if current_status not in self.terminal_states: - self.log.info("Deferring the polling to RayJobTrigger...") - self.defer( - trigger=RayJobTrigger( - job_id=self.job_id, - conn_id=self.conn_id, - xcom_dashboard_url=self.dashboard_url, - ray_cluster_yaml=self.ray_cluster_yaml, - gpu_device_plugin_yaml=self.gpu_device_plugin_yaml, - poll_interval=self.poll_interval, - fetch_logs=self.fetch_logs, - ), - method_name="execute_complete", - timeout=self.job_timeout_seconds, - ) - elif current_status == JobStatus.SUCCEEDED: - self.log.info("Job %s completed successfully", self.job_id) - elif current_status == JobStatus.FAILED: - raise AirflowException(f"Job failed:\n{self.job_id}") - elif current_status == JobStatus.STOPPED: - raise AirflowException(f"Job was cancelled:\n{self.job_id}") - else: - raise AirflowException( - f"Encountered unexpected state `{current_status}` for job_id `{self.job_id}`" - ) - return self.job_id - except Exception as e: - self._delete_cluster() - raise AirflowException(f"SubmitRayJob operator failed due to {e}. Cleaning up resources...") + #try: + self.log.info("::group:: (SubmitJob 1/5) Setup Cluster") + self._setup_cluster(context=context) + self.log.info("::endgroup::") + + self.log.info("::group:: (SubmitJob 2/5) Identify Dashboard URL") + self.dashboard_url = self._get_dashboard_url(context) + self.log.info("::endgroup::") + + self.log.info("::group:: (SubmitJob 3/5) Submit job") + self.log.info(f"Ray job submitted with id: {self.job_id}") + self.job_id = self.hook.submit_ray_job( + dashboard_url=self.dashboard_url, + entrypoint=self.entrypoint, + runtime_env=self.runtime_env, + entrypoint_num_cpus=self.num_cpus, + entrypoint_num_gpus=self.num_gpus, + entrypoint_memory=self.memory, + entrypoint_resources=self.ray_resources, + ) + self.log.info("::endgroup::") + + self.log.info("::group:: (SubmitJob 4/5) Wait for completion") + if self.wait_for_completion: + current_status = self.hook.get_ray_job_status(self.dashboard_url, self.job_id) + self.log.info(f"Current job status for {self.job_id} is: {current_status}") + + if current_status not in self.terminal_states: + self.log.info("Deferring the polling to RayJobTrigger...") + self.defer( + trigger=RayJobTrigger( + job_id=self.job_id, + conn_id=self.conn_id, + xcom_dashboard_url=self.dashboard_url, + ray_cluster_yaml=self.ray_cluster_yaml, + gpu_device_plugin_yaml=self.gpu_device_plugin_yaml, + poll_interval=self.poll_interval, + fetch_logs=self.fetch_logs, + ), + method_name="execute_complete", + timeout=self.job_timeout_seconds, + ) def execute_complete(self, context: Context, event: dict[str, Any]) -> None: """ @@ -322,13 +318,25 @@ def execute_complete(self, context: Context, event: dict[str, Any]) -> None: :param event: The event containing the job execution result. :raises AirflowException: If the job execution fails, is cancelled, or reaches an unexpected state. """ - try: - if event["status"] in [JobStatus.STOPPED, JobStatus.FAILED]: - self.log.info(f"Ray job {self.job_id} execution not completed successfully...") - raise AirflowException(f"Job {self.job_id} {event['status'].lower()}: {event['message']}") - elif event["status"] == JobStatus.SUCCEEDED: - self.log.info(f"Ray job {self.job_id} execution succeeded.") + self.log.info("::endgroup::") + self.log.info("::group:: (SubmitJob 5/5) Execution completed") + + self._delete_cluster() + + job_status = event["status"] + if job_status == JobStatus.SUCCEEDED: + self.log.info("Job %s completed successfully", self.job_id) + return self.job_id + else: + self.log.info(f"Ray job {self.job_id} execution not completed successfully...") + if job_status in (JobStatus.FAILED, JobStatus.STOPPED): + msg = f"Job {self.job_id} {job_status.lower()}: {event['message']}" else: - raise AirflowException(f"Unexpected event status for job {self.job_id}: {event['status']}") - finally: - self._delete_cluster() + msg = f"Encountered unexpected state `{job_status}` for job_id `{self.job_id}`" + + self.log.info("::endgroup::") + + raise AirflowException(msg) + + + diff --git a/ray_provider/operators/__init__.py b/ray_provider/operators/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/ray_provider/triggers/ray.py b/ray_provider/triggers.py similarity index 70% rename from ray_provider/triggers/ray.py rename to ray_provider/triggers.py index 745c74f..0b5be0f 100644 --- a/ray_provider/triggers/ray.py +++ b/ray_provider/triggers.py @@ -7,7 +7,7 @@ from airflow.triggers.base import BaseTrigger, TriggerEvent from ray.job_submission import JobStatus -from ray_provider.hooks.ray import RayHook +from ray_provider.hooks import RayHook class RayJobTrigger(BaseTrigger): @@ -51,7 +51,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]: :return: A tuple containing the fully qualified class name and a dictionary of its parameters. """ return ( - "ray_provider.triggers.ray.RayJobTrigger", + "ray_provider.triggers.RayJobTrigger", { "job_id": self.job_id, "conn_id": self.conn_id, @@ -81,18 +81,15 @@ async def cleanup(self) -> None: resources are not deleted. """ - try: - if self.ray_cluster_yaml: - self.log.info(f"Attempting to delete Ray cluster using YAML: {self.ray_cluster_yaml}") - loop = asyncio.get_running_loop() - await loop.run_in_executor( - None, self.hook.delete_ray_cluster, self.ray_cluster_yaml, self.gpu_device_plugin_yaml - ) - self.log.info("Ray cluster deletion process completed") - else: - self.log.info("No Ray cluster YAML provided, skipping cluster deletion") - except Exception as e: - self.log.error(f"Unexpected error during cleanup: {str(e)}") + if self.ray_cluster_yaml: + self.log.info(f"Attempting to delete Ray cluster using YAML: {self.ray_cluster_yaml}") + loop = asyncio.get_running_loop() + await loop.run_in_executor( + None, self.hook.delete_ray_cluster, self.ray_cluster_yaml, self.gpu_device_plugin_yaml + ) + self.log.info("Ray cluster deletion process completed") + else: + self.log.info("No Ray cluster YAML provided, skipping cluster deletion") async def _poll_status(self) -> None: while not self._is_terminal_state(): @@ -118,28 +115,34 @@ async def run(self) -> AsyncIterator[TriggerEvent]: :yield: TriggerEvent containing the status, message, and job ID related to the job. """ - try: - self.log.info(f"Polling for job {self.job_id} every {self.poll_interval} seconds...") - - tasks = [self._poll_status()] - if self.fetch_logs: - tasks.append(self._stream_logs()) - - await asyncio.gather(*tasks) - - completed_status = self.hook.get_ray_job_status(self.dashboard_url, self.job_id) - self.log.info(f"Status of completed job {self.job_id} is: {completed_status}") - yield TriggerEvent( - { - "status": completed_status, - "message": f"Job {self.job_id} completed with status {completed_status}", - "job_id": self.job_id, - } - ) - except Exception as e: - self.log.error(f"Error occurred: {str(e)}") - await self.cleanup() - yield TriggerEvent({"status": str(JobStatus.FAILED), "message": str(e), "job_id": self.job_id}) + # This is used indirectly when the Ray decorator is used. + # If not imported, DAGs that used the Ray decorator fail when triggered + + self.log.info(f"::group:: Trigger 1/2: Checking the job status") + self.log.info(f"Polling for job {self.job_id} every {self.poll_interval} seconds...") + + tasks = [self._poll_status()] + if self.fetch_logs: + tasks.append(self._stream_logs()) + self.log.info(f"::endgroup::") + await asyncio.gather(*tasks) + + self.log.info(f"::group:: Trigger 2/2: Job reached a terminal state") + completed_status = self.hook.get_ray_job_status(self.dashboard_url, self.job_id) + self.log.info(f"Status of completed job {self.job_id} is: {completed_status}") + self.log.info(f"::endgroup::") + + yield TriggerEvent( + { + "status": completed_status, + "message": f"Job {self.job_id} completed with status {completed_status}", + "job_id": self.job_id, + } + ) + #except Exception as e: + # self.log.error(f"Error occurred: {str(e)}") + # await self.cleanup() + # yield TriggerEvent({"status": str(JobStatus.FAILED), "message": str(e), "job_id": self.job_id}) def _is_terminal_state(self) -> bool: """ diff --git a/ray_provider/triggers/__init__.py b/ray_provider/triggers/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/decorators/__init__.py b/tests/decorators/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/hooks/__init__.py b/tests/hooks/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/operators/__init__.py b/tests/operators/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/decorators/test_ray_decorators.py b/tests/test_ray_decorators.py similarity index 92% rename from tests/decorators/test_ray_decorators.py rename to tests/test_ray_decorators.py index 109f748..115a4ca 100644 --- a/tests/decorators/test_ray_decorators.py +++ b/tests/test_ray_decorators.py @@ -5,7 +5,8 @@ from airflow.exceptions import AirflowException from airflow.utils.context import Context -from ray_provider.decorators.ray import _RayDecoratedOperator, ray +from ray_provider.decorators import _RayDecoratedOperator +from ray_provider.decorators import ray as ray_decorator class TestRayDecoratedOperator: @@ -81,7 +82,7 @@ def dummy_callable(): _RayDecoratedOperator(task_id="test_task", config=config, python_callable=dummy_callable) @patch.object(_RayDecoratedOperator, "_extract_function_body") - @patch("ray_provider.decorators.ray.SubmitRayJob.execute") + @patch("ray_provider.decorators.SubmitRayJob.execute") def test_execute_decorated_function(self, mock_super_execute, mock_extract_function_body): config = { "runtime_env": {"pip": ["ray"]}, @@ -101,7 +102,7 @@ def dummy_callable(): assert operator.entrypoint == "python script.py" assert "working_dir" in operator.runtime_env - @patch("ray_provider.decorators.ray.SubmitRayJob.execute") + @patch("ray_provider.decorators.SubmitRayJob.execute") def test_execute_with_entrypoint(self, mock_super_execute): config = { "entrypoint": "python my_script.py", @@ -119,7 +120,7 @@ def dummy_callable(): assert result == "success" assert operator.entrypoint == "python my_script.py" - @patch("ray_provider.decorators.ray.SubmitRayJob.execute") + @patch("ray_provider.decorators.SubmitRayJob.execute") def test_execute_failure(self, mock_super_execute): config = {} @@ -136,14 +137,14 @@ def dummy_callable(): def test_extract_function_body(self): config = {} - @ray.task() + @ray_decorator.task() def dummy_callable(): return "dummy" operator = _RayDecoratedOperator(task_id="test_task", config=config, python_callable=dummy_callable) function_body = operator._extract_function_body( - """@ray.task() + """@ray_decorator.task() def dummy_callable(): return "dummy" """ @@ -158,7 +159,7 @@ def dummy_callable(): class TestRayTaskDecorator: def test_ray_task_decorator(self): - @ray.task() + @ray_decorator.task() def dummy_function(): return "dummy" @@ -167,7 +168,7 @@ def dummy_function(): assert dummy_function.operator_class == _RayDecoratedOperator def test_ray_task_decorator_with_multiple_outputs(self): - @ray.task(multiple_outputs=True) + @ray_decorator.task(multiple_outputs=True) def dummy_function(): return {"key": "value"} @@ -182,7 +183,7 @@ def test_ray_task_decorator_with_config(self): "memory": 1024, } - @ray.task(**config) + @ray_decorator.task(**config) def dummy_function(): return "dummy" diff --git a/tests/hooks/test_ray_hooks.py b/tests/test_ray_hooks.py similarity index 79% rename from tests/hooks/test_ray_hooks.py rename to tests/test_ray_hooks.py index 95c787d..cfa33b7 100644 --- a/tests/hooks/test_ray_hooks.py +++ b/tests/test_ray_hooks.py @@ -8,19 +8,19 @@ from kubernetes.client.exceptions import ApiException from ray.job_submission import JobStatus -from ray_provider.hooks.ray import RayHook +from ray_provider.hooks import RayHook class TestRayHook: @pytest.fixture def ray_hook(self): - with patch("ray_provider.hooks.ray.KubernetesHook.get_connection") as mock_get_connection: + with patch("ray_provider.hooks.KubernetesHook.get_connection") as mock_get_connection: mock_connection = Mock() mock_connection.extra_dejson = {"kube_config_path": None, "kube_config": None, "cluster_context": None} mock_get_connection.return_value = mock_connection - with patch("ray_provider.hooks.ray.KubernetesHook.__init__", return_value=None): + with patch("ray_provider.hooks.KubernetesHook.__init__", return_value=None): hook = RayHook(conn_id="test_conn") # Manually set the necessary attributes hook.namespace = "default" @@ -40,7 +40,7 @@ def test_get_connection_form_widgets(self): assert "kube_config_path" in widgets assert "namespace" in widgets - @patch("ray_provider.hooks.ray.JobSubmissionClient") + @patch("ray_provider.hooks.JobSubmissionClient") def test_ray_client(self, mock_job_client, ray_hook): mock_job_client.return_value = MagicMock() client = ray_hook.ray_client() @@ -54,16 +54,16 @@ def test_ray_client(self, mock_job_client, ray_hook): verify=ray_hook.verify, ) - @patch("ray_provider.hooks.ray.JobSubmissionClient") + @patch("ray_provider.hooks.JobSubmissionClient") def test_submit_ray_job(self, mock_job_client, ray_hook): mock_client_instance = mock_job_client.return_value mock_client_instance.submit_job.return_value = "test_job_id" job_id = ray_hook.submit_ray_job(dashboard_url="http://example.com", entrypoint="test_entry") assert job_id == "test_job_id" - @patch("ray_provider.hooks.ray.KubernetesHook.get_connection") - @patch("ray_provider.hooks.ray.KubernetesHook.__init__") - @patch("ray_provider.hooks.ray.config.load_kube_config") + @patch("ray_provider.hooks.KubernetesHook.get_connection") + @patch("ray_provider.hooks.KubernetesHook.__init__") + @patch("ray_provider.hooks.config.load_kube_config") def test_setup_kubeconfig_path(self, mock_load_kube_config, mock_kubernetes_init, mock_get_connection): mock_kubernetes_init.return_value = None mock_get_connection.return_value = MagicMock(conn_id="test_conn", extra_dejson={}) @@ -74,9 +74,9 @@ def test_setup_kubeconfig_path(self, mock_load_kube_config, mock_kubernetes_init assert hook.kubeconfig == "/tmp/fake_kubeconfig" mock_load_kube_config.assert_called_once_with(config_file="/tmp/fake_kubeconfig", context="test_context") - @patch("ray_provider.hooks.ray.KubernetesHook.get_connection") - @patch("ray_provider.hooks.ray.KubernetesHook.__init__") - @patch("ray_provider.hooks.ray.config.load_kube_config") + @patch("ray_provider.hooks.KubernetesHook.get_connection") + @patch("ray_provider.hooks.KubernetesHook.__init__") + @patch("ray_provider.hooks.config.load_kube_config") @patch("tempfile.NamedTemporaryFile") def test_setup_kubeconfig_content( self, mock_tempfile, mock_load_kube_config, mock_kubernetes_init, mock_get_connection @@ -95,8 +95,8 @@ def test_setup_kubeconfig_content( mock_tempfile.return_value.__enter__.return_value.write.assert_called_once_with(kubeconfig_content.encode()) mock_load_kube_config.assert_called_once_with(config_file="/tmp/fake_kubeconfig", context="test_context") - @patch("ray_provider.hooks.ray.KubernetesHook.get_connection") - @patch("ray_provider.hooks.ray.KubernetesHook.__init__") + @patch("ray_provider.hooks.KubernetesHook.get_connection") + @patch("ray_provider.hooks.KubernetesHook.__init__") def test_setup_kubeconfig_invalid_config(self, mock_kubernetes_init, mock_get_connection): mock_kubernetes_init.return_value = None mock_get_connection.return_value = MagicMock(conn_id="test_conn", extra_dejson={}) @@ -110,8 +110,8 @@ def test_setup_kubeconfig_invalid_config(self, mock_kubernetes_init, mock_get_co "kube_config are mutually exclusive. You can only use one option at a time." ) - @patch("ray_provider.hooks.ray.KubernetesHook.get_connection") - @patch("ray_provider.hooks.ray.JobSubmissionClient") + @patch("ray_provider.hooks.KubernetesHook.get_connection") + @patch("ray_provider.hooks.JobSubmissionClient") def test_delete_ray_job(self, mock_job_client, mock_get_connection): mock_get_connection.return_value = MagicMock(conn_id="test_conn", extra_dejson={}) mock_client_instance = mock_job_client.return_value @@ -120,8 +120,8 @@ def test_delete_ray_job(self, mock_job_client, mock_get_connection): result = hook.delete_ray_job("http://example.com", job_id="test_job_id") assert result == "deleted" - @patch("ray_provider.hooks.ray.KubernetesHook.get_connection") - @patch("ray_provider.hooks.ray.JobSubmissionClient") + @patch("ray_provider.hooks.KubernetesHook.get_connection") + @patch("ray_provider.hooks.JobSubmissionClient") def test_get_ray_job_status(self, mock_job_client, mock_get_connection): mock_get_connection.return_value = MagicMock(conn_id="test_conn", extra_dejson={}) mock_client_instance = mock_job_client.return_value @@ -130,8 +130,8 @@ def test_get_ray_job_status(self, mock_job_client, mock_get_connection): status = hook.get_ray_job_status("http://example.com", "test_job_id") assert status == JobStatus.SUCCEEDED - @patch("ray_provider.hooks.ray.KubernetesHook.get_connection") - @patch("ray_provider.hooks.ray.JobSubmissionClient") + @patch("ray_provider.hooks.KubernetesHook.get_connection") + @patch("ray_provider.hooks.JobSubmissionClient") def test_get_ray_job_logs(self, mock_job_client, mock_get_connection): mock_get_connection.return_value = MagicMock(conn_id="test_conn", extra_dejson={}) mock_client_instance = mock_job_client.return_value @@ -154,8 +154,8 @@ def test_get_ray_job_logs(self, mock_job_client, mock_get_connection): ) mock_client_instance.get_job_logs.assert_called_once_with(job_id=job_id) - @patch("ray_provider.hooks.ray.KubernetesHook.get_connection") - @patch("ray_provider.hooks.ray.requests.get") + @patch("ray_provider.hooks.KubernetesHook.get_connection") + @patch("ray_provider.hooks.requests.get") @patch("builtins.open", new_callable=mock_open, read_data="key: value\n") def test_load_yaml_content(self, mock_open, mock_requests, mock_get_connection): mock_get_connection.return_value = MagicMock(conn_id="test_conn", extra_dejson={}) @@ -203,8 +203,8 @@ def test_validate_yaml_file_not_exists(self, mock_isfile, ray_hook): assert "The specified YAML file does not exist" in str(exc_info.value) mock_isfile.assert_called_once_with("non_existent_file.yaml") - @patch("ray_provider.hooks.ray.KubernetesHook.get_connection") - @patch("ray_provider.hooks.ray.socket.socket") + @patch("ray_provider.hooks.KubernetesHook.get_connection") + @patch("ray_provider.hooks.socket.socket") def test_is_port_open(self, mock_socket, mock_get_connection): mock_get_connection.return_value = MagicMock(conn_id="test_conn", extra_dejson={}) mock_socket_instance = mock_socket.return_value @@ -215,7 +215,7 @@ def test_is_port_open(self, mock_socket, mock_get_connection): result = hook._is_port_open("localhost", 8080) assert result is True - @patch("ray_provider.hooks.ray.RayHook.core_v1_client") + @patch("ray_provider.hooks.RayHook.core_v1_client") def test_get_service_success(self, mock_core_v1_client, ray_hook): mock_service = Mock(spec=client.V1Service) mock_core_v1_client.read_namespaced_service.return_value = mock_service @@ -225,7 +225,7 @@ def test_get_service_success(self, mock_core_v1_client, ray_hook): assert service == mock_service mock_core_v1_client.read_namespaced_service.assert_called_once_with("test-service", "default") - @patch("ray_provider.hooks.ray.RayHook.core_v1_client") + @patch("ray_provider.hooks.RayHook.core_v1_client") def test_get_service_not_found(self, mock_core_v1_client, ray_hook): mock_core_v1_client.read_namespaced_service.side_effect = client.exceptions.ApiException(status=404) @@ -285,8 +285,8 @@ def test_get_load_balancer_details_no_ip_or_hostname(self, ray_hook): assert lb_details is None - @patch("ray_provider.hooks.ray.RayHook.log") - @patch("ray_provider.hooks.ray.subprocess.run") + @patch("ray_provider.hooks.RayHook.log") + @patch("ray_provider.hooks.subprocess.run") def test_run_bash_command_exception(self, mock_subprocess_run, mock_log, ray_hook): mock_subprocess_run.side_effect = subprocess.CalledProcessError( returncode=1, cmd="test command", output="test output", stderr="test error" @@ -313,9 +313,9 @@ def test_run_bash_command_exception(self, mock_subprocess_run, mock_log, ray_hoo env=ray_hook._run_bash_command.__globals__["os"].environ.copy(), ) - @patch("ray_provider.hooks.ray.KubernetesHook.get_connection") - @patch("ray_provider.hooks.ray.KubernetesHook.__init__") - @patch("ray_provider.hooks.ray.subprocess.run") + @patch("ray_provider.hooks.KubernetesHook.get_connection") + @patch("ray_provider.hooks.KubernetesHook.__init__") + @patch("ray_provider.hooks.subprocess.run") def test_install_kuberay_operator(self, mock_subprocess_run, mock_kubernetes_init, mock_get_connection): mock_kubernetes_init.return_value = None mock_get_connection.return_value = MagicMock(conn_id="test_conn", extra_dejson={}) @@ -327,9 +327,9 @@ def test_install_kuberay_operator(self, mock_subprocess_run, mock_kubernetes_ini assert "install output" in stdout assert stderr == "" - @patch("ray_provider.hooks.ray.KubernetesHook.get_connection") - @patch("ray_provider.hooks.ray.KubernetesHook.__init__") - @patch("ray_provider.hooks.ray.subprocess.run") + @patch("ray_provider.hooks.KubernetesHook.get_connection") + @patch("ray_provider.hooks.KubernetesHook.__init__") + @patch("ray_provider.hooks.subprocess.run") def test_uninstall_kuberay_operator(self, mock_subprocess_run, mock_kubernetes_init, mock_get_connection): mock_kubernetes_init.return_value = None mock_get_connection.return_value = MagicMock(conn_id="test_conn", extra_dejson={}) @@ -341,9 +341,9 @@ def test_uninstall_kuberay_operator(self, mock_subprocess_run, mock_kubernetes_i assert "uninstall output" in stdout assert stderr == "" - @patch("ray_provider.hooks.ray.RayHook._get_service") - @patch("ray_provider.hooks.ray.RayHook._get_load_balancer_details") - @patch("ray_provider.hooks.ray.RayHook._check_load_balancer_readiness") + @patch("ray_provider.hooks.RayHook._get_service") + @patch("ray_provider.hooks.RayHook._get_load_balancer_details") + @patch("ray_provider.hooks.RayHook._check_load_balancer_readiness") def test_wait_for_load_balancer_success( self, mock_check_readiness, mock_get_lb_details, mock_get_service, ray_hook ): @@ -370,9 +370,9 @@ def test_wait_for_load_balancer_success( mock_get_lb_details.assert_called_once_with(mock_service) mock_check_readiness.assert_called_once() - @patch("ray_provider.hooks.ray.RayHook._get_service") - @patch("ray_provider.hooks.ray.RayHook._get_load_balancer_details") - @patch("ray_provider.hooks.ray.RayHook._is_port_open") + @patch("ray_provider.hooks.RayHook._get_service") + @patch("ray_provider.hooks.RayHook._get_load_balancer_details") + @patch("ray_provider.hooks.RayHook._is_port_open") def test_wait_for_load_balancer_timeout(self, mock_is_port_open, mock_get_lb_details, mock_get_service, ray_hook): mock_service = Mock(spec=client.V1Service) mock_get_service.return_value = mock_service @@ -390,7 +390,7 @@ def test_wait_for_load_balancer_timeout(self, mock_is_port_open, mock_get_lb_det assert "LoadBalancer did not become ready after 2 attempts" in str(exc_info.value) - @patch("ray_provider.hooks.ray.RayHook._get_service") + @patch("ray_provider.hooks.RayHook._get_service") def test_wait_for_load_balancer_service_not_found(self, mock_get_service, ray_hook): mock_get_service.side_effect = AirflowException("Service test-service not found") @@ -399,7 +399,7 @@ def test_wait_for_load_balancer_service_not_found(self, mock_get_service, ray_ho assert "LoadBalancer did not become ready after 1 attempts" in str(exc_info.value) - @patch("ray_provider.hooks.ray.RayHook._is_port_open") + @patch("ray_provider.hooks.RayHook._is_port_open") def test_check_load_balancer_readiness_ip(self, mock_is_port_open, ray_hook): mock_is_port_open.return_value = True lb_details = {"ip": "192.168.1.1", "hostname": None, "ports": [{"name": "http", "port": 80}]} @@ -409,7 +409,7 @@ def test_check_load_balancer_readiness_ip(self, mock_is_port_open, ray_hook): assert result == "192.168.1.1" mock_is_port_open.assert_called_once_with("192.168.1.1", 80) - @patch("ray_provider.hooks.ray.RayHook._is_port_open") + @patch("ray_provider.hooks.RayHook._is_port_open") def test_check_load_balancer_readiness_hostname(self, mock_is_port_open, ray_hook): mock_is_port_open.side_effect = [False, True] lb_details = { @@ -424,7 +424,7 @@ def test_check_load_balancer_readiness_hostname(self, mock_is_port_open, ray_hoo mock_is_port_open.assert_any_call("192.168.1.1", 80) mock_is_port_open.assert_any_call("example.com", 80) - @patch("ray_provider.hooks.ray.RayHook._is_port_open") + @patch("ray_provider.hooks.RayHook._is_port_open") def test_check_load_balancer_readiness_not_ready(self, mock_is_port_open, ray_hook): mock_is_port_open.return_value = False lb_details = {"ip": "192.168.1.1", "hostname": "example.com", "ports": [{"name": "http", "port": 80}]} @@ -435,10 +435,10 @@ def test_check_load_balancer_readiness_not_ready(self, mock_is_port_open, ray_ho mock_is_port_open.assert_any_call("192.168.1.1", 80) mock_is_port_open.assert_any_call("example.com", 80) - @patch("ray_provider.hooks.ray.KubernetesHook.get_connection") - @patch("ray_provider.hooks.ray.KubernetesHook.__init__") - @patch("ray_provider.hooks.ray.client.AppsV1Api.read_namespaced_daemon_set") - @patch("ray_provider.hooks.ray.config.load_kube_config") + @patch("ray_provider.hooks.KubernetesHook.get_connection") + @patch("ray_provider.hooks.KubernetesHook.__init__") + @patch("ray_provider.hooks.client.AppsV1Api.read_namespaced_daemon_set") + @patch("ray_provider.hooks.config.load_kube_config") def test_get_daemon_set( self, mock_load_kube_config, mock_read_daemon_set, mock_kubernetes_init, mock_get_connection ): @@ -454,10 +454,10 @@ def test_get_daemon_set( assert daemon_set.metadata.name == "test-daemonset" - @patch("ray_provider.hooks.ray.KubernetesHook.get_connection") - @patch("ray_provider.hooks.ray.KubernetesHook.__init__") - @patch("ray_provider.hooks.ray.client.AppsV1Api.read_namespaced_daemon_set") - @patch("ray_provider.hooks.ray.config.load_kube_config") + @patch("ray_provider.hooks.KubernetesHook.get_connection") + @patch("ray_provider.hooks.KubernetesHook.__init__") + @patch("ray_provider.hooks.client.AppsV1Api.read_namespaced_daemon_set") + @patch("ray_provider.hooks.config.load_kube_config") def test_get_daemon_set_not_found( self, mock_load_kube_config, mock_read_daemon_set, mock_kubernetes_init, mock_get_connection ): @@ -470,10 +470,10 @@ def test_get_daemon_set_not_found( assert daemon_set is None - @patch("ray_provider.hooks.ray.KubernetesHook.get_connection") - @patch("ray_provider.hooks.ray.KubernetesHook.__init__") - @patch("ray_provider.hooks.ray.client.AppsV1Api.create_namespaced_daemon_set") - @patch("ray_provider.hooks.ray.config.load_kube_config") + @patch("ray_provider.hooks.KubernetesHook.get_connection") + @patch("ray_provider.hooks.KubernetesHook.__init__") + @patch("ray_provider.hooks.client.AppsV1Api.create_namespaced_daemon_set") + @patch("ray_provider.hooks.config.load_kube_config") def test_create_daemon_set( self, mock_load_kube_config, mock_create_daemon_set, mock_kubernetes_init, mock_get_connection ): @@ -490,10 +490,10 @@ def test_create_daemon_set( assert daemon_set.metadata.name == "test-daemonset" - @patch("ray_provider.hooks.ray.KubernetesHook.get_connection") - @patch("ray_provider.hooks.ray.KubernetesHook.__init__") - @patch("ray_provider.hooks.ray.client.AppsV1Api.create_namespaced_daemon_set") - @patch("ray_provider.hooks.ray.config.load_kube_config") + @patch("ray_provider.hooks.KubernetesHook.get_connection") + @patch("ray_provider.hooks.KubernetesHook.__init__") + @patch("ray_provider.hooks.client.AppsV1Api.create_namespaced_daemon_set") + @patch("ray_provider.hooks.config.load_kube_config") def test_create_daemon_set_no_body( self, mock_load_kube_config, mock_create_daemon_set, mock_kubernetes_init, mock_get_connection ): @@ -505,10 +505,10 @@ def test_create_daemon_set_no_body( assert daemon_set is None - @patch("ray_provider.hooks.ray.KubernetesHook.get_connection") - @patch("ray_provider.hooks.ray.KubernetesHook.__init__") - @patch("ray_provider.hooks.ray.client.AppsV1Api.create_namespaced_daemon_set") - @patch("ray_provider.hooks.ray.config.load_kube_config") + @patch("ray_provider.hooks.KubernetesHook.get_connection") + @patch("ray_provider.hooks.KubernetesHook.__init__") + @patch("ray_provider.hooks.client.AppsV1Api.create_namespaced_daemon_set") + @patch("ray_provider.hooks.config.load_kube_config") def test_create_daemon_set_exception( self, mock_load_kube_config, mock_create_daemon_set, mock_kubernetes_init, mock_get_connection ): @@ -522,10 +522,10 @@ def test_create_daemon_set_exception( assert daemon_set is None - @patch("ray_provider.hooks.ray.KubernetesHook.get_connection") - @patch("ray_provider.hooks.ray.KubernetesHook.__init__") - @patch("ray_provider.hooks.ray.client.AppsV1Api.delete_namespaced_daemon_set") - @patch("ray_provider.hooks.ray.config.load_kube_config") + @patch("ray_provider.hooks.KubernetesHook.get_connection") + @patch("ray_provider.hooks.KubernetesHook.__init__") + @patch("ray_provider.hooks.client.AppsV1Api.delete_namespaced_daemon_set") + @patch("ray_provider.hooks.config.load_kube_config") def test_delete_daemon_set( self, mock_load_kube_config, mock_delete_daemon_set, mock_kubernetes_init, mock_get_connection ): @@ -538,10 +538,10 @@ def test_delete_daemon_set( assert response.status == "Success" - @patch("ray_provider.hooks.ray.KubernetesHook.get_connection") - @patch("ray_provider.hooks.ray.KubernetesHook.__init__") - @patch("ray_provider.hooks.ray.client.AppsV1Api.delete_namespaced_daemon_set") - @patch("ray_provider.hooks.ray.config.load_kube_config") + @patch("ray_provider.hooks.KubernetesHook.get_connection") + @patch("ray_provider.hooks.KubernetesHook.__init__") + @patch("ray_provider.hooks.client.AppsV1Api.delete_namespaced_daemon_set") + @patch("ray_provider.hooks.config.load_kube_config") def test_delete_daemon_set_not_found( self, mock_load_kube_config, mock_delete_daemon_set, mock_kubernetes_init, mock_get_connection ): @@ -554,10 +554,10 @@ def test_delete_daemon_set_not_found( assert response is None - @patch("ray_provider.hooks.ray.KubernetesHook.get_connection") - @patch("ray_provider.hooks.ray.KubernetesHook.__init__") - @patch("ray_provider.hooks.ray.client.AppsV1Api.delete_namespaced_daemon_set") - @patch("ray_provider.hooks.ray.config.load_kube_config") + @patch("ray_provider.hooks.KubernetesHook.get_connection") + @patch("ray_provider.hooks.KubernetesHook.__init__") + @patch("ray_provider.hooks.client.AppsV1Api.delete_namespaced_daemon_set") + @patch("ray_provider.hooks.config.load_kube_config") def test_delete_daemon_set_exception( self, mock_load_kube_config, mock_delete_daemon_set, mock_kubernetes_init, mock_get_connection ): @@ -570,8 +570,8 @@ def test_delete_daemon_set_exception( assert response is None - @patch("ray_provider.hooks.ray.KubernetesHook.get_connection") - @patch("ray_provider.hooks.ray.KubernetesHook.__init__") + @patch("ray_provider.hooks.KubernetesHook.get_connection") + @patch("ray_provider.hooks.KubernetesHook.__init__") @patch("os.path.isfile") def test_validate_yaml_file_not_found(self, mock_is_file, mock_kubernetes_init, mock_get_connection): mock_kubernetes_init.return_value = None @@ -584,8 +584,8 @@ def test_validate_yaml_file_not_found(self, mock_is_file, mock_kubernetes_init, assert "The specified YAML file does not exist" in str(exc_info.value) - @patch("ray_provider.hooks.ray.KubernetesHook.get_connection") - @patch("ray_provider.hooks.ray.KubernetesHook.__init__") + @patch("ray_provider.hooks.KubernetesHook.get_connection") + @patch("ray_provider.hooks.KubernetesHook.__init__") @patch("os.path.isfile") def test_validate_yaml_file_invalid_extension(self, mock_is_file, mock_kubernetes_init, mock_get_connection): mock_kubernetes_init.return_value = None @@ -598,13 +598,13 @@ def test_validate_yaml_file_invalid_extension(self, mock_is_file, mock_kubernete assert "The specified YAML file must have a .yaml or .yml extension" in str(exc_info.value) - @patch("ray_provider.hooks.ray.RayHook._validate_yaml_file") - @patch("ray_provider.hooks.ray.RayHook.install_kuberay_operator") - @patch("ray_provider.hooks.ray.RayHook.load_yaml_content") - @patch("ray_provider.hooks.ray.RayHook.get_custom_object") - @patch("ray_provider.hooks.ray.RayHook.create_custom_object") - @patch("ray_provider.hooks.ray.RayHook._setup_gpu_driver") - @patch("ray_provider.hooks.ray.RayHook._setup_load_balancer") + @patch("ray_provider.hooks.RayHook._validate_yaml_file") + @patch("ray_provider.hooks.RayHook.install_kuberay_operator") + @patch("ray_provider.hooks.RayHook.load_yaml_content") + @patch("ray_provider.hooks.RayHook.get_custom_object") + @patch("ray_provider.hooks.RayHook.create_custom_object") + @patch("ray_provider.hooks.RayHook._setup_gpu_driver") + @patch("ray_provider.hooks.RayHook._setup_load_balancer") def test_setup_ray_cluster_success( self, mock_setup_load_balancer, @@ -639,13 +639,13 @@ def test_setup_ray_cluster_success( mock_setup_gpu_driver.assert_called_once_with(gpu_device_plugin_yaml="gpu.yaml") mock_setup_load_balancer.assert_called_once() - @patch("ray_provider.hooks.ray.RayHook._validate_yaml_file") - @patch("ray_provider.hooks.ray.RayHook.uninstall_kuberay_operator") - @patch("ray_provider.hooks.ray.RayHook.load_yaml_content") - @patch("ray_provider.hooks.ray.RayHook.get_custom_object") - @patch("ray_provider.hooks.ray.RayHook.delete_custom_object") - @patch("ray_provider.hooks.ray.RayHook.get_daemon_set") - @patch("ray_provider.hooks.ray.RayHook.delete_daemon_set") + @patch("ray_provider.hooks.RayHook._validate_yaml_file") + @patch("ray_provider.hooks.RayHook.uninstall_kuberay_operator") + @patch("ray_provider.hooks.RayHook.load_yaml_content") + @patch("ray_provider.hooks.RayHook.get_custom_object") + @patch("ray_provider.hooks.RayHook.delete_custom_object") + @patch("ray_provider.hooks.RayHook.get_daemon_set") + @patch("ray_provider.hooks.RayHook.delete_daemon_set") def test_delete_ray_cluster_success( self, mock_delete_daemon_set, @@ -677,15 +677,15 @@ def test_delete_ray_cluster_success( mock_delete_custom_object.assert_called_once() mock_uninstall_kuberay_operator.assert_called_once() - @patch("ray_provider.hooks.ray.JobSubmissionClient") + @patch("ray_provider.hooks.JobSubmissionClient") def test_ray_client_exception(self, mock_job_client, ray_hook): mock_job_client.side_effect = Exception("Connection failed") with pytest.raises(AirflowException) as exc_info: ray_hook.ray_client() assert str(exc_info.value) == "Failed to create Ray JobSubmissionClient: Connection failed" - @patch("ray_provider.hooks.ray.RayHook.get_custom_object") - @patch("ray_provider.hooks.ray.RayHook.create_custom_object") + @patch("ray_provider.hooks.RayHook.get_custom_object") + @patch("ray_provider.hooks.RayHook.create_custom_object") def test_create_or_update_cluster_exception(self, mock_create, mock_get, ray_hook): mock_get.side_effect = client.exceptions.ApiException(status=500, reason="Internal Server Error") with pytest.raises(AirflowException) as exc_info: @@ -700,8 +700,8 @@ def test_create_or_update_cluster_exception(self, mock_create, mock_get, ray_hoo ) assert "Error accessing Ray cluster 'test-cluster'" in str(exc_info.value) - @patch("ray_provider.hooks.ray.RayHook.get_custom_object") - @patch("ray_provider.hooks.ray.RayHook.custom_object_client") + @patch("ray_provider.hooks.RayHook.get_custom_object") + @patch("ray_provider.hooks.RayHook.custom_object_client") def test_create_or_update_cluster_update(self, mock_client, mock_get, ray_hook): mock_get.return_value = {"metadata": {"name": "test-cluster"}} ray_hook._create_or_update_cluster( @@ -722,12 +722,12 @@ def test_create_or_update_cluster_update(self, mock_client, mock_get, ray_hook): body={"spec": {"some": "config"}}, ) - @patch("ray_provider.hooks.ray.RayHook._validate_yaml_file") - @patch("ray_provider.hooks.ray.RayHook.install_kuberay_operator") - @patch("ray_provider.hooks.ray.RayHook.load_yaml_content") - @patch("ray_provider.hooks.ray.RayHook._create_or_update_cluster") - @patch("ray_provider.hooks.ray.RayHook._setup_gpu_driver") - @patch("ray_provider.hooks.ray.RayHook._setup_load_balancer") + @patch("ray_provider.hooks.RayHook._validate_yaml_file") + @patch("ray_provider.hooks.RayHook.install_kuberay_operator") + @patch("ray_provider.hooks.RayHook.load_yaml_content") + @patch("ray_provider.hooks.RayHook._create_or_update_cluster") + @patch("ray_provider.hooks.RayHook._setup_gpu_driver") + @patch("ray_provider.hooks.RayHook._setup_load_balancer") def test_setup_ray_cluster_exception( self, mock_setup_lb, @@ -750,13 +750,13 @@ def test_setup_ray_cluster_exception( ) assert "Failed to set up Ray cluster: Cluster creation failed" in str(exc_info.value) - @patch("ray_provider.hooks.ray.RayHook._validate_yaml_file") - @patch("ray_provider.hooks.ray.RayHook.load_yaml_content") - @patch("ray_provider.hooks.ray.RayHook.get_custom_object") - @patch("ray_provider.hooks.ray.RayHook.delete_custom_object") - @patch("ray_provider.hooks.ray.RayHook.get_daemon_set") - @patch("ray_provider.hooks.ray.RayHook.delete_daemon_set") - @patch("ray_provider.hooks.ray.RayHook.uninstall_kuberay_operator") + @patch("ray_provider.hooks.RayHook._validate_yaml_file") + @patch("ray_provider.hooks.RayHook.load_yaml_content") + @patch("ray_provider.hooks.RayHook.get_custom_object") + @patch("ray_provider.hooks.RayHook.delete_custom_object") + @patch("ray_provider.hooks.RayHook.get_daemon_set") + @patch("ray_provider.hooks.RayHook.delete_daemon_set") + @patch("ray_provider.hooks.RayHook.uninstall_kuberay_operator") def test_delete_ray_cluster_exception( self, mock_uninstall_operator, diff --git a/tests/operators/test_ray_operators.py b/tests/test_ray_operators.py similarity index 96% rename from tests/operators/test_ray_operators.py rename to tests/test_ray_operators.py index 11d3d0c..72df6b5 100644 --- a/tests/operators/test_ray_operators.py +++ b/tests/test_ray_operators.py @@ -5,14 +5,14 @@ from airflow.exceptions import AirflowException, TaskDeferred from ray.job_submission import JobStatus -from ray_provider.operators.ray import DeleteRayCluster, SetupRayCluster, SubmitRayJob -from ray_provider.triggers.ray import RayJobTrigger +from ray_provider.operators import DeleteRayCluster, SetupRayCluster, SubmitRayJob +from ray_provider.triggers import RayJobTrigger class TestSetupRayCluster: @pytest.fixture def mock_hook(self): - with patch("ray_provider.operators.ray.RayHook") as mock: + with patch("ray_provider.operators.RayHook") as mock: yield mock.return_value @pytest.fixture @@ -48,7 +48,7 @@ def test_init_default_values(self): assert operator.update_if_exists is False def test_hook_property(self, operator): - with patch("ray_provider.operators.ray.RayHook") as mock_ray_hook: + with patch("ray_provider.operators.RayHook") as mock_ray_hook: hook = operator.hook mock_ray_hook.assert_called_once_with(conn_id=operator.conn_id) assert hook == mock_ray_hook.return_value @@ -68,7 +68,7 @@ def test_execute(self, operator, mock_hook): class TestDeleteRayCluster: @pytest.fixture def mock_hook(self): - with patch("ray_provider.operators.ray.RayHook") as mock: + with patch("ray_provider.operators.RayHook") as mock: yield mock.return_value @pytest.fixture @@ -98,7 +98,7 @@ def test_init_default_gpu_plugin(self): ) def test_hook_property(self, operator): - with patch("ray_provider.operators.ray.RayHook") as mock_ray_hook: + with patch("ray_provider.operators.RayHook") as mock_ray_hook: hook = operator.hook mock_ray_hook.assert_called_once_with(conn_id=operator.conn_id) assert hook == mock_ray_hook.return_value @@ -113,7 +113,7 @@ class TestSubmitRayJob: @pytest.fixture def mock_hook(self): - with patch("ray_provider.operators.ray.RayHook") as mock: + with patch("ray_provider.operators.RayHook") as mock: yield mock.return_value @pytest.fixture @@ -228,7 +228,7 @@ def test_get_dashboard_url_without_xcom(self, context): assert result is None - @patch("ray_provider.operators.ray.RayHook") + @patch("ray_provider.operators.RayHook") def test_setup_cluster(self, mock_ray_hook, context): operator = SubmitRayJob( task_id="test_task", @@ -254,7 +254,7 @@ def test_setup_cluster(self, mock_ray_hook, context): update_if_exists=True, ) - @patch("ray_provider.operators.ray.RayHook") + @patch("ray_provider.operators.RayHook") def test_delete_cluster(self, mock_ray_hook): operator = SubmitRayJob( task_id="test_task", @@ -371,7 +371,7 @@ def test_template_fields(self): "job_timeout_seconds", ) - @patch("ray_provider.operators.ray.RayHook") + @patch("ray_provider.operators.RayHook") def test_setup_cluster_exception(self, mock_ray_hook, context): operator = SubmitRayJob( task_id="test_task", @@ -392,7 +392,7 @@ def test_setup_cluster_exception(self, mock_ray_hook, context): assert str(exc_info.value) == "Cluster setup failed" mock_hook.setup_ray_cluster.assert_called_once() - @patch("ray_provider.operators.ray.RayHook") + @patch("ray_provider.operators.RayHook") def test_delete_cluster_exception(self, mock_ray_hook): operator = SubmitRayJob( task_id="test_task", diff --git a/tests/triggers/test_ray_triggers.py b/tests/test_ray_triggers.py similarity index 85% rename from tests/triggers/test_ray_triggers.py rename to tests/test_ray_triggers.py index f82b521..f97611e 100644 --- a/tests/triggers/test_ray_triggers.py +++ b/tests/test_ray_triggers.py @@ -5,7 +5,7 @@ from airflow.triggers.base import TriggerEvent from ray.job_submission import JobStatus -from ray_provider.triggers.ray import RayJobTrigger +from ray_provider.triggers import RayJobTrigger class TestRayJobTrigger: @@ -22,8 +22,8 @@ def trigger(self): ) @pytest.mark.asyncio - @patch("ray_provider.triggers.ray.RayJobTrigger._is_terminal_state") - @patch("ray_provider.triggers.ray.RayJobTrigger.hook") + @patch("ray_provider.triggers.RayJobTrigger._is_terminal_state") + @patch("ray_provider.triggers.RayJobTrigger.hook") async def test_run_no_job_id(self, mock_hook, mock_is_terminal): mock_is_terminal.return_value = True mock_hook.get_ray_job_status.return_value = JobStatus.FAILED @@ -42,8 +42,8 @@ async def test_run_no_job_id(self, mock_hook, mock_is_terminal): ) @pytest.mark.asyncio - @patch("ray_provider.triggers.ray.RayJobTrigger._is_terminal_state") - @patch("ray_provider.triggers.ray.RayJobTrigger.hook") + @patch("ray_provider.triggers.RayJobTrigger._is_terminal_state") + @patch("ray_provider.triggers.RayJobTrigger.hook") async def test_run_job_succeeded(self, mock_hook, mock_is_terminal): mock_is_terminal.side_effect = [False, True] mock_hook.get_ray_job_status.return_value = JobStatus.SUCCEEDED @@ -66,8 +66,8 @@ async def test_run_job_succeeded(self, mock_hook, mock_is_terminal): ) @pytest.mark.asyncio - @patch("ray_provider.triggers.ray.RayJobTrigger._is_terminal_state") - @patch("ray_provider.triggers.ray.RayJobTrigger.hook") + @patch("ray_provider.triggers.RayJobTrigger._is_terminal_state") + @patch("ray_provider.triggers.RayJobTrigger.hook") async def test_run_job_stopped(self, mock_hook, mock_is_terminal, trigger): mock_is_terminal.side_effect = [False, True] mock_hook.get_ray_job_status.return_value = JobStatus.STOPPED @@ -84,8 +84,8 @@ async def test_run_job_stopped(self, mock_hook, mock_is_terminal, trigger): ) @pytest.mark.asyncio - @patch("ray_provider.triggers.ray.RayJobTrigger._is_terminal_state") - @patch("ray_provider.triggers.ray.RayJobTrigger.hook") + @patch("ray_provider.triggers.RayJobTrigger._is_terminal_state") + @patch("ray_provider.triggers.RayJobTrigger.hook") async def test_run_job_failed(self, mock_hook, mock_is_terminal, trigger): mock_is_terminal.side_effect = [False, True] mock_hook.get_ray_job_status.return_value = JobStatus.FAILED @@ -102,9 +102,9 @@ async def test_run_job_failed(self, mock_hook, mock_is_terminal, trigger): ) @pytest.mark.asyncio - @patch("ray_provider.triggers.ray.RayJobTrigger._is_terminal_state") - @patch("ray_provider.triggers.ray.RayJobTrigger.hook") - @patch("ray_provider.triggers.ray.RayJobTrigger._stream_logs") + @patch("ray_provider.triggers.RayJobTrigger._is_terminal_state") + @patch("ray_provider.triggers.RayJobTrigger.hook") + @patch("ray_provider.triggers.RayJobTrigger._stream_logs") async def test_run_with_log_streaming(self, mock_stream_logs, mock_hook, mock_is_terminal, trigger): mock_is_terminal.side_effect = [False, True] mock_hook.get_ray_job_status.return_value = JobStatus.SUCCEEDED @@ -123,7 +123,7 @@ async def test_run_with_log_streaming(self, mock_stream_logs, mock_hook, mock_is ) @pytest.mark.asyncio - @patch("ray_provider.triggers.ray.RayJobTrigger.hook") + @patch("ray_provider.triggers.RayJobTrigger.hook") async def test_stream_logs(self, mock_hook, trigger): # Create a mock async iterator async def mock_async_iterator(): @@ -133,7 +133,7 @@ async def mock_async_iterator(): # Set up the mock to return an async iterator mock_hook.get_ray_tail_logs.return_value = mock_async_iterator() - with patch("ray_provider.triggers.ray.RayJobTrigger.log") as mock_log: + with patch("ray_provider.triggers.RayJobTrigger.log") as mock_log: await trigger._stream_logs() mock_log.info.assert_any_call("::group::test_job_id logs") @@ -144,7 +144,7 @@ async def mock_async_iterator(): def test_serialize(self, trigger): serialized = trigger.serialize() assert serialized == ( - "ray_provider.triggers.ray.RayJobTrigger", + "ray_provider.triggers.RayJobTrigger", { "job_id": "test_job_id", "conn_id": "test_conn", @@ -157,7 +157,7 @@ def test_serialize(self, trigger): ) @pytest.mark.asyncio - @patch("ray_provider.triggers.ray.RayJobTrigger.hook") + @patch("ray_provider.triggers.RayJobTrigger.hook") async def test_is_terminal_state(self, mock_hook, trigger): mock_hook.get_ray_job_status.side_effect = [ JobStatus.PENDING, @@ -212,7 +212,7 @@ async def test_cleanup_with_exception(self, mock_log_error, mock_hook, trigger): @pytest.mark.asyncio @patch("asyncio.sleep", new_callable=AsyncMock) - @patch("ray_provider.triggers.ray.RayJobTrigger._is_terminal_state") + @patch("ray_provider.triggers.RayJobTrigger._is_terminal_state") async def test_poll_status(self, mock_is_terminal, mock_sleep, trigger): mock_is_terminal.side_effect = [False, False, True] @@ -222,9 +222,9 @@ async def test_poll_status(self, mock_is_terminal, mock_sleep, trigger): mock_sleep.assert_called_with(1) @pytest.mark.asyncio - @patch("ray_provider.triggers.ray.RayJobTrigger._is_terminal_state") - @patch("ray_provider.triggers.ray.RayJobTrigger.hook") - @patch("ray_provider.triggers.ray.RayJobTrigger.cleanup") + @patch("ray_provider.triggers.RayJobTrigger._is_terminal_state") + @patch("ray_provider.triggers.RayJobTrigger.hook") + @patch("ray_provider.triggers.RayJobTrigger.cleanup") async def test_run_with_exception(self, mock_cleanup, mock_hook, mock_is_terminal, trigger): mock_is_terminal.side_effect = Exception("Test exception") diff --git a/tests/triggers/__init__.py b/tests/triggers/__init__.py deleted file mode 100644 index e69de29..0000000 From c17fed4cb8db227ae50e99d3b63f9e734a0eb1ed Mon Sep 17 00:00:00 2001 From: Tatiana Al-Chueyr Date: Wed, 27 Nov 2024 09:47:27 +0000 Subject: [PATCH 2/6] Release 0.3.0a7 --- ray_provider/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ray_provider/__init__.py b/ray_provider/__init__.py index 04fb8c2..276b0d6 100644 --- a/ray_provider/__init__.py +++ b/ray_provider/__init__.py @@ -1,6 +1,8 @@ from __future__ import annotations -__version__ = "0.2.1" + +__version__ = "0.3.0a7" + from typing import Any From 09ef5d7714dbb186ce9df43f865d000fdb49572f Mon Sep 17 00:00:00 2001 From: Tatiana Al-Chueyr Date: Wed, 27 Nov 2024 10:08:39 +0000 Subject: [PATCH 3/6] Update the changelog with hcanges since 0.21 --- CHANGELOG.rst | 50 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 8e0e244..720cf0d 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,6 +1,56 @@ CHANGELOG ========= + +0.3.0a7 (2024-11-27) +------------------ + +**Breaking changes** + +* Simplify the project structure and debugging by @tatiana in #93 + +In order to improve the development and troubleshooting DAGs created with this provider, we introduced breaking changes +to the folder structure. It was flattened and the import paths to existing decorators, hooks, operators and trigger +changed, as documented in the table below: + ++-----------+---------------------------------------------+-----------------------------------------+ +| Type | Previous import path | Current import path | ++===========+=============================================+=========================================+ +| Decorator | ray_provider.decorators.ray.ray | ray_provider.decorators.ray | +| Hook | ray_provider.hooks.ray.RayHook | ray_provider.hooks.RayHook | +| Operator | ray_provider.operators.ray.DeleteRayCluster | ray_provider.operators.DeleteRayCluster | +| Operator | ray_provider.operators.ray.SetupRayCluster | ray_provider.operators.SetupRayCluster | +| Operator | ray_provider.operators.ray.SubmitRayJob | ray_provider.operators.SubmitRayJob | +| Trigger | ray_provider.triggers.ray.RayJobTrigger | ray_provider.triggers.RayJobTrigger | ++-----------+---------------------------------------------+-----------------------------------------+ + +**Features** + +* Dynamic configuration support by @tatiana in #94 (TODO: change) +* Support running Ray jobs indefinitely without timing out by @venkatajagannath and @tatiana in #74 + +**Bug fixes** + +* Fix integration test and bug in load balancer wait logic by @pankajastro in #85 +* Bugfix: Better exception handling and cluster clean up by @venkatajagannath in #68 + +**Docs** + +* Add docs to deploy project on Astro Cloud by @pankajastro in #90 +* Fix dead reference in docs index page by @pankajastro in #87 +* Cloud Auth documentation update by @venkatajagannath in #58 +* Improve main docs page by @TJaniF in #71 + +**Others** + +* Fix the local development environment and update documentation by @tatiana in #92 +* Enable secrect detection precommit check by @pankajastro in #91 +* Add astro cli project + kind Raycluster setup instruction by @pankajastro in #83 +* Update CODEOWNERS by @tatiana in #84 +* Allow tests to run for PRs from forked repos by @venkatajagannath in #72 +* CI improvement by @venkatajagannath in #73 +* CI fix related to broken coverage upload artifact by @pankajkoti in #60 + 0.2.1 (2024-09-04) ------------------ From f28d8ad45681ed6bf4f901992af1090c085c574b Mon Sep 17 00:00:00 2001 From: Tatiana Al-Chueyr Date: Thu, 28 Nov 2024 00:00:13 +0000 Subject: [PATCH 4/6] Remove redundant tests --- tests/test_ray_decorators.py | 192 --------- tests/test_ray_hooks.py | 774 ----------------------------------- tests/test_ray_operators.py | 593 --------------------------- tests/test_ray_triggers.py | 241 ----------- 4 files changed, 1800 deletions(-) delete mode 100644 tests/test_ray_decorators.py delete mode 100644 tests/test_ray_hooks.py delete mode 100644 tests/test_ray_operators.py delete mode 100644 tests/test_ray_triggers.py diff --git a/tests/test_ray_decorators.py b/tests/test_ray_decorators.py deleted file mode 100644 index 115a4ca..0000000 --- a/tests/test_ray_decorators.py +++ /dev/null @@ -1,192 +0,0 @@ -from datetime import timedelta -from unittest.mock import MagicMock, patch - -import pytest -from airflow.exceptions import AirflowException -from airflow.utils.context import Context - -from ray_provider.decorators import _RayDecoratedOperator -from ray_provider.decorators import ray as ray_decorator - - -class TestRayDecoratedOperator: - def test_initialization(self): - config = { - "conn_id": "ray_default", - "entrypoint": "python my_script.py", - "runtime_env": {"pip": ["ray"]}, - "num_cpus": 2, - "num_gpus": 1, - "memory": 1024, - "resources": {"custom_resource": 1}, - "fetch_logs": True, - "wait_for_completion": True, - "job_timeout_seconds": 300, - "poll_interval": 30, - "xcom_task_key": "ray_result", - } - - def dummy_callable(): - pass - - operator = _RayDecoratedOperator(task_id="test_task", config=config, python_callable=dummy_callable) - - assert operator.conn_id == "ray_default" - assert operator.entrypoint == "python my_script.py" - assert operator.runtime_env == {"pip": ["ray"]} - assert operator.num_cpus == 2 - assert operator.num_gpus == 1 - assert operator.memory == 1024 - assert operator.ray_resources == {"custom_resource": 1} - assert operator.fetch_logs == True - assert operator.wait_for_completion == True - assert operator.job_timeout_seconds == timedelta(seconds=300) - assert operator.poll_interval == 30 - assert operator.xcom_task_key == "ray_result" - - def test_initialization_defaults(self): - config = {} - - def dummy_callable(): - pass - - operator = _RayDecoratedOperator(task_id="test_task", config=config, python_callable=dummy_callable) - - assert operator.conn_id == "" - assert operator.entrypoint == "python script.py" - assert operator.runtime_env == {} - assert operator.num_cpus == 1 - assert operator.num_gpus == 0 - assert operator.memory is None - assert operator.ray_resources is None - assert operator.fetch_logs == True - assert operator.wait_for_completion == True - assert operator.job_timeout_seconds == timedelta(seconds=600) - assert operator.poll_interval == 60 - assert operator.xcom_task_key is None - - def test_invalid_config_raises_exception(self): - config = { - "num_cpus": "invalid_number", - } - - def dummy_callable(): - pass - - with pytest.raises(TypeError): - _RayDecoratedOperator(task_id="test_task", config=config, python_callable=dummy_callable) - - config["num_cpus"] = 1 - config["num_gpus"] = "invalid_number" - with pytest.raises(TypeError): - _RayDecoratedOperator(task_id="test_task", config=config, python_callable=dummy_callable) - - @patch.object(_RayDecoratedOperator, "_extract_function_body") - @patch("ray_provider.decorators.SubmitRayJob.execute") - def test_execute_decorated_function(self, mock_super_execute, mock_extract_function_body): - config = { - "runtime_env": {"pip": ["ray"]}, - } - - def dummy_callable(): - pass - - context = MagicMock(spec=Context) - operator = _RayDecoratedOperator(task_id="test_task", config=config, python_callable=dummy_callable) - mock_extract_function_body.return_value = "def dummy_callable():\n pass\n" - mock_super_execute.return_value = "success" - - result = operator.execute(context) - - assert result == "success" - assert operator.entrypoint == "python script.py" - assert "working_dir" in operator.runtime_env - - @patch("ray_provider.decorators.SubmitRayJob.execute") - def test_execute_with_entrypoint(self, mock_super_execute): - config = { - "entrypoint": "python my_script.py", - } - - def dummy_callable(): - pass - - context = MagicMock(spec=Context) - operator = _RayDecoratedOperator(task_id="test_task", config=config, python_callable=dummy_callable) - mock_super_execute.return_value = "success" - - result = operator.execute(context) - - assert result == "success" - assert operator.entrypoint == "python my_script.py" - - @patch("ray_provider.decorators.SubmitRayJob.execute") - def test_execute_failure(self, mock_super_execute): - config = {} - - def dummy_callable(): - pass - - context = MagicMock(spec=Context) - operator = _RayDecoratedOperator(task_id="test_task", config=config, python_callable=dummy_callable) - mock_super_execute.side_effect = Exception("Ray job failed") - - with pytest.raises(AirflowException): - operator.execute(context) - - def test_extract_function_body(self): - config = {} - - @ray_decorator.task() - def dummy_callable(): - return "dummy" - - operator = _RayDecoratedOperator(task_id="test_task", config=config, python_callable=dummy_callable) - - function_body = operator._extract_function_body( - """@ray_decorator.task() - def dummy_callable(): - return "dummy" - """ - ) - assert ( - function_body - == """def dummy_callable(): - return "dummy" -""" - ) - - -class TestRayTaskDecorator: - def test_ray_task_decorator(self): - @ray_decorator.task() - def dummy_function(): - return "dummy" - - assert callable(dummy_function) - assert hasattr(dummy_function, "operator_class") - assert dummy_function.operator_class == _RayDecoratedOperator - - def test_ray_task_decorator_with_multiple_outputs(self): - @ray_decorator.task(multiple_outputs=True) - def dummy_function(): - return {"key": "value"} - - assert callable(dummy_function) - assert hasattr(dummy_function, "operator_class") - assert dummy_function.operator_class == _RayDecoratedOperator - - def test_ray_task_decorator_with_config(self): - config = { - "num_cpus": 2, - "num_gpus": 1, - "memory": 1024, - } - - @ray_decorator.task(**config) - def dummy_function(): - return "dummy" - - assert callable(dummy_function) - assert hasattr(dummy_function, "operator_class") - assert dummy_function.operator_class == _RayDecoratedOperator diff --git a/tests/test_ray_hooks.py b/tests/test_ray_hooks.py deleted file mode 100644 index cfa33b7..0000000 --- a/tests/test_ray_hooks.py +++ /dev/null @@ -1,774 +0,0 @@ -import subprocess -from unittest.mock import MagicMock, Mock, mock_open, patch - -import pytest -import yaml -from airflow.exceptions import AirflowException -from kubernetes import client -from kubernetes.client.exceptions import ApiException -from ray.job_submission import JobStatus - -from ray_provider.hooks import RayHook - - -class TestRayHook: - - @pytest.fixture - def ray_hook(self): - with patch("ray_provider.hooks.KubernetesHook.get_connection") as mock_get_connection: - mock_connection = Mock() - mock_connection.extra_dejson = {"kube_config_path": None, "kube_config": None, "cluster_context": None} - mock_get_connection.return_value = mock_connection - - with patch("ray_provider.hooks.KubernetesHook.__init__", return_value=None): - hook = RayHook(conn_id="test_conn") - # Manually set the necessary attributes - hook.namespace = "default" - hook.kubeconfig = "/path/to/kubeconfig" - return hook - - def test_get_ui_field_behaviour(self): - expected_fields = { - "hidden_fields": ["host", "schema", "login", "password", "port", "extra"], - "relabeling": {}, - } - assert RayHook.get_ui_field_behaviour() == expected_fields - - def test_get_connection_form_widgets(self): - widgets = RayHook.get_connection_form_widgets() - assert "address" in widgets - assert "kube_config_path" in widgets - assert "namespace" in widgets - - @patch("ray_provider.hooks.JobSubmissionClient") - def test_ray_client(self, mock_job_client, ray_hook): - mock_job_client.return_value = MagicMock() - client = ray_hook.ray_client() - assert isinstance(client, MagicMock) - mock_job_client.assert_called_once_with( - address=ray_hook.address, - create_cluster_if_needed=ray_hook.create_cluster_if_needed, - cookies=ray_hook.cookies, - metadata=ray_hook.metadata, - headers=ray_hook.headers, - verify=ray_hook.verify, - ) - - @patch("ray_provider.hooks.JobSubmissionClient") - def test_submit_ray_job(self, mock_job_client, ray_hook): - mock_client_instance = mock_job_client.return_value - mock_client_instance.submit_job.return_value = "test_job_id" - job_id = ray_hook.submit_ray_job(dashboard_url="http://example.com", entrypoint="test_entry") - assert job_id == "test_job_id" - - @patch("ray_provider.hooks.KubernetesHook.get_connection") - @patch("ray_provider.hooks.KubernetesHook.__init__") - @patch("ray_provider.hooks.config.load_kube_config") - def test_setup_kubeconfig_path(self, mock_load_kube_config, mock_kubernetes_init, mock_get_connection): - mock_kubernetes_init.return_value = None - mock_get_connection.return_value = MagicMock(conn_id="test_conn", extra_dejson={}) - - hook = RayHook(conn_id="test_conn") - hook._setup_kubeconfig("/tmp/fake_kubeconfig", None, "test_context") - - assert hook.kubeconfig == "/tmp/fake_kubeconfig" - mock_load_kube_config.assert_called_once_with(config_file="/tmp/fake_kubeconfig", context="test_context") - - @patch("ray_provider.hooks.KubernetesHook.get_connection") - @patch("ray_provider.hooks.KubernetesHook.__init__") - @patch("ray_provider.hooks.config.load_kube_config") - @patch("tempfile.NamedTemporaryFile") - def test_setup_kubeconfig_content( - self, mock_tempfile, mock_load_kube_config, mock_kubernetes_init, mock_get_connection - ): - mock_kubernetes_init.return_value = None - mock_get_connection.return_value = MagicMock(conn_id="test_conn", extra_dejson={}) - - mock_tempfile.return_value.__enter__.return_value.name = "/tmp/fake_kubeconfig" - mock_tempfile.return_value.__enter__.return_value.write = MagicMock() - - hook = RayHook(conn_id="test_conn") - kubeconfig_content = "apiVersion: v1\nclusters:\n- cluster:\n server: https://127.0.0.1:6443" - - hook._setup_kubeconfig(None, kubeconfig_content, "test_context") - - mock_tempfile.return_value.__enter__.return_value.write.assert_called_once_with(kubeconfig_content.encode()) - mock_load_kube_config.assert_called_once_with(config_file="/tmp/fake_kubeconfig", context="test_context") - - @patch("ray_provider.hooks.KubernetesHook.get_connection") - @patch("ray_provider.hooks.KubernetesHook.__init__") - def test_setup_kubeconfig_invalid_config(self, mock_kubernetes_init, mock_get_connection): - mock_kubernetes_init.return_value = None - mock_get_connection.return_value = MagicMock(conn_id="test_conn", extra_dejson={}) - - hook = RayHook(conn_id="test_conn") - with pytest.raises(AirflowException) as exc_info: - hook._setup_kubeconfig("/tmp/fake_kubeconfig", "kubeconfig_content", "test_context") - - assert str(exc_info.value) == ( - "Invalid connection configuration. Options kube_config_path and " - "kube_config are mutually exclusive. You can only use one option at a time." - ) - - @patch("ray_provider.hooks.KubernetesHook.get_connection") - @patch("ray_provider.hooks.JobSubmissionClient") - def test_delete_ray_job(self, mock_job_client, mock_get_connection): - mock_get_connection.return_value = MagicMock(conn_id="test_conn", extra_dejson={}) - mock_client_instance = mock_job_client.return_value - mock_client_instance.delete_job.return_value = "deleted" - hook = RayHook(conn_id="test_conn") - result = hook.delete_ray_job("http://example.com", job_id="test_job_id") - assert result == "deleted" - - @patch("ray_provider.hooks.KubernetesHook.get_connection") - @patch("ray_provider.hooks.JobSubmissionClient") - def test_get_ray_job_status(self, mock_job_client, mock_get_connection): - mock_get_connection.return_value = MagicMock(conn_id="test_conn", extra_dejson={}) - mock_client_instance = mock_job_client.return_value - mock_client_instance.get_job_status.return_value = JobStatus.SUCCEEDED - hook = RayHook(conn_id="test_conn") - status = hook.get_ray_job_status("http://example.com", "test_job_id") - assert status == JobStatus.SUCCEEDED - - @patch("ray_provider.hooks.KubernetesHook.get_connection") - @patch("ray_provider.hooks.JobSubmissionClient") - def test_get_ray_job_logs(self, mock_job_client, mock_get_connection): - mock_get_connection.return_value = MagicMock(conn_id="test_conn", extra_dejson={}) - mock_client_instance = mock_job_client.return_value - mock_client_instance.get_job_logs.return_value = "test logs" - - hook = RayHook(conn_id="test_conn") - dashboard_url = "http://example.com:8265" - job_id = "test_job_id" - - logs = hook.get_ray_job_logs(dashboard_url, job_id) - - assert logs == "test logs" - mock_job_client.assert_called_once_with( - address=dashboard_url, - create_cluster_if_needed=False, - cookies=None, - metadata=None, - headers=None, - verify=False, - ) - mock_client_instance.get_job_logs.assert_called_once_with(job_id=job_id) - - @patch("ray_provider.hooks.KubernetesHook.get_connection") - @patch("ray_provider.hooks.requests.get") - @patch("builtins.open", new_callable=mock_open, read_data="key: value\n") - def test_load_yaml_content(self, mock_open, mock_requests, mock_get_connection): - mock_get_connection.return_value = MagicMock(conn_id="test_conn", extra_dejson={}) - hook = RayHook(conn_id="test_conn") - result = hook.load_yaml_content("test_path") - assert result == {"key": "value"} - - mock_requests.return_value.status_code = 200 - mock_requests.return_value.text = "key: value\n" - result = hook.load_yaml_content("http://test-url") - assert result == {"key": "value"} - - @patch("os.path.isfile") - @patch("builtins.open", new_callable=mock_open, read_data="key: value\n") - def test_validate_yaml_file_success(self, mock_file, mock_isfile, ray_hook): - mock_isfile.return_value = True - - # Test with a valid YAML file - ray_hook._validate_yaml_file("valid_file.yaml") - - mock_isfile.assert_called_once_with("valid_file.yaml") - mock_file.assert_called_once_with("valid_file.yaml") - - @patch("os.path.isfile") - @patch("builtins.open", new_callable=mock_open, read_data="invalid: yaml: content") - def test_validate_yaml_file_invalid_yaml(self, mock_file, mock_isfile, ray_hook): - mock_isfile.return_value = True - - # Test with an invalid YAML file - with pytest.raises(AirflowException) as exc_info: - with patch("yaml.safe_load", side_effect=yaml.YAMLError("Invalid YAML")): - ray_hook._validate_yaml_file("invalid_file.yaml") - - assert "The specified YAML file is not valid YAML" in str(exc_info.value) - mock_isfile.assert_called_once_with("invalid_file.yaml") - mock_file.assert_called_once_with("invalid_file.yaml") - - @patch("os.path.isfile") - def test_validate_yaml_file_not_exists(self, mock_isfile, ray_hook): - mock_isfile.return_value = False - - with pytest.raises(AirflowException) as exc_info: - ray_hook._validate_yaml_file("non_existent_file.yaml") - - assert "The specified YAML file does not exist" in str(exc_info.value) - mock_isfile.assert_called_once_with("non_existent_file.yaml") - - @patch("ray_provider.hooks.KubernetesHook.get_connection") - @patch("ray_provider.hooks.socket.socket") - def test_is_port_open(self, mock_socket, mock_get_connection): - mock_get_connection.return_value = MagicMock(conn_id="test_conn", extra_dejson={}) - mock_socket_instance = mock_socket.return_value - - # Test successful connection - mock_socket_instance.connect.return_value = None - hook = RayHook(conn_id="test_conn") - result = hook._is_port_open("localhost", 8080) - assert result is True - - @patch("ray_provider.hooks.RayHook.core_v1_client") - def test_get_service_success(self, mock_core_v1_client, ray_hook): - mock_service = Mock(spec=client.V1Service) - mock_core_v1_client.read_namespaced_service.return_value = mock_service - - service = ray_hook._get_service("test-service", "default") - - assert service == mock_service - mock_core_v1_client.read_namespaced_service.assert_called_once_with("test-service", "default") - - @patch("ray_provider.hooks.RayHook.core_v1_client") - def test_get_service_not_found(self, mock_core_v1_client, ray_hook): - mock_core_v1_client.read_namespaced_service.side_effect = client.exceptions.ApiException(status=404) - - with pytest.raises(AirflowException) as exc_info: - ray_hook._get_service("non-existent-service", "default") - - assert "Service non-existent-service not found" in str(exc_info.value) - - def test_get_load_balancer_details_with_ingress(self, ray_hook): - mock_service = Mock(spec=client.V1Service) - mock_ingress = Mock(spec=client.V1LoadBalancerIngress) - mock_ingress.ip = "192.168.1.1" - mock_ingress.hostname = None - mock_service.status.load_balancer.ingress = [mock_ingress] - - mock_port = Mock() - mock_port.name = "http" - mock_port.port = 80 - mock_service.spec.ports = [mock_port] - - lb_details = ray_hook._get_load_balancer_details(mock_service) - - assert lb_details == {"ip": "192.168.1.1", "hostname": None, "ports": [{"name": "http", "port": 80}]} - - def test_get_load_balancer_details_with_hostname(self, ray_hook): - mock_service = Mock(spec=client.V1Service) - mock_ingress = Mock(spec=client.V1LoadBalancerIngress) - mock_ingress.ip = None - mock_ingress.hostname = "example.com" - mock_service.status.load_balancer.ingress = [mock_ingress] - - mock_port = Mock() - mock_port.name = "https" - mock_port.port = 443 - mock_service.spec.ports = [mock_port] - - lb_details = ray_hook._get_load_balancer_details(mock_service) - - assert lb_details == {"hostname": "example.com", "ip": None, "ports": [{"name": "https", "port": 443}]} - - def test_get_load_balancer_details_no_ingress(self, ray_hook): - mock_service = Mock(spec=client.V1Service) - mock_service.status.load_balancer.ingress = None - - lb_details = ray_hook._get_load_balancer_details(mock_service) - - assert lb_details is None - - def test_get_load_balancer_details_no_ip_or_hostname(self, ray_hook): - mock_service = Mock(spec=client.V1Service) - mock_ingress = Mock(spec=client.V1LoadBalancerIngress) - mock_ingress.ip = None - mock_ingress.hostname = None - mock_service.status.load_balancer.ingress = [mock_ingress] - - lb_details = ray_hook._get_load_balancer_details(mock_service) - - assert lb_details is None - - @patch("ray_provider.hooks.RayHook.log") - @patch("ray_provider.hooks.subprocess.run") - def test_run_bash_command_exception(self, mock_subprocess_run, mock_log, ray_hook): - mock_subprocess_run.side_effect = subprocess.CalledProcessError( - returncode=1, cmd="test command", output="test output", stderr="test error" - ) - - stdout, stderr = ray_hook._run_bash_command("test command") - - assert stdout is None - assert stderr is None - - mock_log.error.assert_any_call( - "An error occurred while executing the command: %s", mock_subprocess_run.side_effect - ) - mock_log.error.assert_any_call("Return code: %s", 1) - mock_log.error.assert_any_call("Standard Output: %s", "test output") - mock_log.error.assert_any_call("Standard Error: %s", "test error") - - mock_subprocess_run.assert_called_once_with( - "test command", - shell=True, - check=True, - text=True, - capture_output=True, - env=ray_hook._run_bash_command.__globals__["os"].environ.copy(), - ) - - @patch("ray_provider.hooks.KubernetesHook.get_connection") - @patch("ray_provider.hooks.KubernetesHook.__init__") - @patch("ray_provider.hooks.subprocess.run") - def test_install_kuberay_operator(self, mock_subprocess_run, mock_kubernetes_init, mock_get_connection): - mock_kubernetes_init.return_value = None - mock_get_connection.return_value = MagicMock(conn_id="test_conn", extra_dejson={}) - mock_subprocess_run.return_value = MagicMock(stdout="install output", stderr="") - - hook = RayHook(conn_id="test_conn") - stdout, stderr = hook.install_kuberay_operator(version="1.0.0") - - assert "install output" in stdout - assert stderr == "" - - @patch("ray_provider.hooks.KubernetesHook.get_connection") - @patch("ray_provider.hooks.KubernetesHook.__init__") - @patch("ray_provider.hooks.subprocess.run") - def test_uninstall_kuberay_operator(self, mock_subprocess_run, mock_kubernetes_init, mock_get_connection): - mock_kubernetes_init.return_value = None - mock_get_connection.return_value = MagicMock(conn_id="test_conn", extra_dejson={}) - mock_subprocess_run.return_value = MagicMock(stdout="uninstall output", stderr="") - - hook = RayHook(conn_id="test_conn") - stdout, stderr = hook.uninstall_kuberay_operator() - - assert "uninstall output" in stdout - assert stderr == "" - - @patch("ray_provider.hooks.RayHook._get_service") - @patch("ray_provider.hooks.RayHook._get_load_balancer_details") - @patch("ray_provider.hooks.RayHook._check_load_balancer_readiness") - def test_wait_for_load_balancer_success( - self, mock_check_readiness, mock_get_lb_details, mock_get_service, ray_hook - ): - mock_service = Mock(spec=client.V1Service) - mock_get_service.return_value = mock_service - - mock_get_lb_details.return_value = { - "hostname": "test-lb.example.com", - "ip": None, - "ports": [{"name": "http", "port": 80}, {"name": "https", "port": 443}], - } - - mock_check_readiness.return_value = "test-lb.example.com" - - result = ray_hook._wait_for_load_balancer("test-service", namespace="default", max_retries=1, retry_interval=1) - - assert result == { - "hostname": "test-lb.example.com", - "ip": None, - "ports": [{"name": "http", "port": 80}, {"name": "https", "port": 443}], - "working_address": "test-lb.example.com", - } - mock_get_service.assert_called_once_with("test-service", "default") - mock_get_lb_details.assert_called_once_with(mock_service) - mock_check_readiness.assert_called_once() - - @patch("ray_provider.hooks.RayHook._get_service") - @patch("ray_provider.hooks.RayHook._get_load_balancer_details") - @patch("ray_provider.hooks.RayHook._is_port_open") - def test_wait_for_load_balancer_timeout(self, mock_is_port_open, mock_get_lb_details, mock_get_service, ray_hook): - mock_service = Mock(spec=client.V1Service) - mock_get_service.return_value = mock_service - - mock_get_lb_details.return_value = { - "hostname": "test-lb.example.com", - "ip": None, - "ports": [{"name": "http", "port": 80}], - } - - mock_is_port_open.return_value = False - - with pytest.raises(AirflowException) as exc_info: - ray_hook._wait_for_load_balancer("test-service", namespace="default", max_retries=2, retry_interval=1) - - assert "LoadBalancer did not become ready after 2 attempts" in str(exc_info.value) - - @patch("ray_provider.hooks.RayHook._get_service") - def test_wait_for_load_balancer_service_not_found(self, mock_get_service, ray_hook): - mock_get_service.side_effect = AirflowException("Service test-service not found") - - with pytest.raises(AirflowException) as exc_info: - ray_hook._wait_for_load_balancer("test-service", namespace="default", max_retries=1, retry_interval=1) - - assert "LoadBalancer did not become ready after 1 attempts" in str(exc_info.value) - - @patch("ray_provider.hooks.RayHook._is_port_open") - def test_check_load_balancer_readiness_ip(self, mock_is_port_open, ray_hook): - mock_is_port_open.return_value = True - lb_details = {"ip": "192.168.1.1", "hostname": None, "ports": [{"name": "http", "port": 80}]} - - result = ray_hook._check_load_balancer_readiness(lb_details) - - assert result == "192.168.1.1" - mock_is_port_open.assert_called_once_with("192.168.1.1", 80) - - @patch("ray_provider.hooks.RayHook._is_port_open") - def test_check_load_balancer_readiness_hostname(self, mock_is_port_open, ray_hook): - mock_is_port_open.side_effect = [False, True] - lb_details = { - "ip": "192.168.1.1", - "hostname": "example.com", - "ports": [{"name": "http", "port": 80}, {"name": "https", "port": 443}], - } - - result = ray_hook._check_load_balancer_readiness(lb_details) - - assert result == "example.com" - mock_is_port_open.assert_any_call("192.168.1.1", 80) - mock_is_port_open.assert_any_call("example.com", 80) - - @patch("ray_provider.hooks.RayHook._is_port_open") - def test_check_load_balancer_readiness_not_ready(self, mock_is_port_open, ray_hook): - mock_is_port_open.return_value = False - lb_details = {"ip": "192.168.1.1", "hostname": "example.com", "ports": [{"name": "http", "port": 80}]} - - result = ray_hook._check_load_balancer_readiness(lb_details) - - assert result is None - mock_is_port_open.assert_any_call("192.168.1.1", 80) - mock_is_port_open.assert_any_call("example.com", 80) - - @patch("ray_provider.hooks.KubernetesHook.get_connection") - @patch("ray_provider.hooks.KubernetesHook.__init__") - @patch("ray_provider.hooks.client.AppsV1Api.read_namespaced_daemon_set") - @patch("ray_provider.hooks.config.load_kube_config") - def test_get_daemon_set( - self, mock_load_kube_config, mock_read_daemon_set, mock_kubernetes_init, mock_get_connection - ): - mock_kubernetes_init.return_value = None - mock_get_connection.return_value = MagicMock(conn_id="test_conn", extra_dejson={}) - - mock_metadata = MagicMock() - mock_metadata.name = "test-daemonset" - mock_read_daemon_set.return_value = MagicMock(metadata=mock_metadata) - - hook = RayHook(conn_id="test_conn") - daemon_set = hook.get_daemon_set(name="test-daemonset") - - assert daemon_set.metadata.name == "test-daemonset" - - @patch("ray_provider.hooks.KubernetesHook.get_connection") - @patch("ray_provider.hooks.KubernetesHook.__init__") - @patch("ray_provider.hooks.client.AppsV1Api.read_namespaced_daemon_set") - @patch("ray_provider.hooks.config.load_kube_config") - def test_get_daemon_set_not_found( - self, mock_load_kube_config, mock_read_daemon_set, mock_kubernetes_init, mock_get_connection - ): - mock_kubernetes_init.return_value = None - mock_get_connection.return_value = MagicMock(conn_id="test_conn", extra_dejson={}) - mock_read_daemon_set.side_effect = ApiException(status=404, reason="Not Found") - - hook = RayHook(conn_id="test_conn") - daemon_set = hook.get_daemon_set(name="test-daemonset") - - assert daemon_set is None - - @patch("ray_provider.hooks.KubernetesHook.get_connection") - @patch("ray_provider.hooks.KubernetesHook.__init__") - @patch("ray_provider.hooks.client.AppsV1Api.create_namespaced_daemon_set") - @patch("ray_provider.hooks.config.load_kube_config") - def test_create_daemon_set( - self, mock_load_kube_config, mock_create_daemon_set, mock_kubernetes_init, mock_get_connection - ): - mock_kubernetes_init.return_value = None - mock_get_connection.return_value = MagicMock(conn_id="test_conn", extra_dejson={}) - - mock_metadata = MagicMock() - mock_metadata.name = "test-daemonset" - mock_create_daemon_set.return_value = MagicMock(metadata=mock_metadata) - - hook = RayHook(conn_id="test_conn") - body = {"metadata": {"name": "test-daemonset"}} - daemon_set = hook.create_daemon_set(name="test-daemonset", body=body) - - assert daemon_set.metadata.name == "test-daemonset" - - @patch("ray_provider.hooks.KubernetesHook.get_connection") - @patch("ray_provider.hooks.KubernetesHook.__init__") - @patch("ray_provider.hooks.client.AppsV1Api.create_namespaced_daemon_set") - @patch("ray_provider.hooks.config.load_kube_config") - def test_create_daemon_set_no_body( - self, mock_load_kube_config, mock_create_daemon_set, mock_kubernetes_init, mock_get_connection - ): - mock_kubernetes_init.return_value = None - mock_get_connection.return_value = MagicMock(conn_id="test_conn", extra_dejson={}) - - hook = RayHook(conn_id="test_conn") - daemon_set = hook.create_daemon_set(name="test-daemonset", body=None) - - assert daemon_set is None - - @patch("ray_provider.hooks.KubernetesHook.get_connection") - @patch("ray_provider.hooks.KubernetesHook.__init__") - @patch("ray_provider.hooks.client.AppsV1Api.create_namespaced_daemon_set") - @patch("ray_provider.hooks.config.load_kube_config") - def test_create_daemon_set_exception( - self, mock_load_kube_config, mock_create_daemon_set, mock_kubernetes_init, mock_get_connection - ): - mock_kubernetes_init.return_value = None - mock_get_connection.return_value = MagicMock(conn_id="test_conn", extra_dejson={}) - mock_create_daemon_set.side_effect = ApiException(status=500, reason="Internal Server Error") - - hook = RayHook(conn_id="test_conn") - body = {"metadata": {"name": "test-daemonset"}} - daemon_set = hook.create_daemon_set(name="test-daemonset", body=body) - - assert daemon_set is None - - @patch("ray_provider.hooks.KubernetesHook.get_connection") - @patch("ray_provider.hooks.KubernetesHook.__init__") - @patch("ray_provider.hooks.client.AppsV1Api.delete_namespaced_daemon_set") - @patch("ray_provider.hooks.config.load_kube_config") - def test_delete_daemon_set( - self, mock_load_kube_config, mock_delete_daemon_set, mock_kubernetes_init, mock_get_connection - ): - mock_kubernetes_init.return_value = None - mock_get_connection.return_value = MagicMock(conn_id="test_conn", extra_dejson={}) - mock_delete_daemon_set.return_value = MagicMock(status="Success") - - hook = RayHook(conn_id="test_conn") - response = hook.delete_daemon_set(name="test-daemonset") - - assert response.status == "Success" - - @patch("ray_provider.hooks.KubernetesHook.get_connection") - @patch("ray_provider.hooks.KubernetesHook.__init__") - @patch("ray_provider.hooks.client.AppsV1Api.delete_namespaced_daemon_set") - @patch("ray_provider.hooks.config.load_kube_config") - def test_delete_daemon_set_not_found( - self, mock_load_kube_config, mock_delete_daemon_set, mock_kubernetes_init, mock_get_connection - ): - mock_kubernetes_init.return_value = None - mock_get_connection.return_value = MagicMock(conn_id="test_conn", extra_dejson={}) - mock_delete_daemon_set.side_effect = ApiException(status=404, reason="Not Found") - - hook = RayHook(conn_id="test_conn") - response = hook.delete_daemon_set(name="test-daemonset") - - assert response is None - - @patch("ray_provider.hooks.KubernetesHook.get_connection") - @patch("ray_provider.hooks.KubernetesHook.__init__") - @patch("ray_provider.hooks.client.AppsV1Api.delete_namespaced_daemon_set") - @patch("ray_provider.hooks.config.load_kube_config") - def test_delete_daemon_set_exception( - self, mock_load_kube_config, mock_delete_daemon_set, mock_kubernetes_init, mock_get_connection - ): - mock_kubernetes_init.return_value = None - mock_get_connection.return_value = MagicMock(conn_id="test_conn", extra_dejson={}) - mock_delete_daemon_set.side_effect = ApiException(status=500, reason="Internal Server Error") - - hook = RayHook(conn_id="test_conn") - response = hook.delete_daemon_set(name="test-daemonset") - - assert response is None - - @patch("ray_provider.hooks.KubernetesHook.get_connection") - @patch("ray_provider.hooks.KubernetesHook.__init__") - @patch("os.path.isfile") - def test_validate_yaml_file_not_found(self, mock_is_file, mock_kubernetes_init, mock_get_connection): - mock_kubernetes_init.return_value = None - mock_get_connection.return_value = MagicMock(conn_id="test_conn", extra_dejson={}) - mock_is_file.return_value = False - - hook = RayHook(conn_id="test_conn") - with pytest.raises(AirflowException) as exc_info: - hook._validate_yaml_file("test.yaml") - - assert "The specified YAML file does not exist" in str(exc_info.value) - - @patch("ray_provider.hooks.KubernetesHook.get_connection") - @patch("ray_provider.hooks.KubernetesHook.__init__") - @patch("os.path.isfile") - def test_validate_yaml_file_invalid_extension(self, mock_is_file, mock_kubernetes_init, mock_get_connection): - mock_kubernetes_init.return_value = None - mock_get_connection.return_value = MagicMock(conn_id="test_conn", extra_dejson={}) - mock_is_file.return_value = True - - hook = RayHook(conn_id="test_conn") - with pytest.raises(AirflowException) as exc_info: - hook._validate_yaml_file("test.txt") - - assert "The specified YAML file must have a .yaml or .yml extension" in str(exc_info.value) - - @patch("ray_provider.hooks.RayHook._validate_yaml_file") - @patch("ray_provider.hooks.RayHook.install_kuberay_operator") - @patch("ray_provider.hooks.RayHook.load_yaml_content") - @patch("ray_provider.hooks.RayHook.get_custom_object") - @patch("ray_provider.hooks.RayHook.create_custom_object") - @patch("ray_provider.hooks.RayHook._setup_gpu_driver") - @patch("ray_provider.hooks.RayHook._setup_load_balancer") - def test_setup_ray_cluster_success( - self, - mock_setup_load_balancer, - mock_setup_gpu_driver, - mock_create_custom_object, - mock_get_custom_object, - mock_load_yaml_content, - mock_install_kuberay_operator, - mock_validate_yaml_file, - ray_hook, - ): - mock_load_yaml_content.return_value = { - "kind": "RayCluster", - "apiVersion": "ray.io/v1", - "metadata": {"name": "test-cluster"}, - } - mock_get_custom_object.side_effect = ApiException(status=404) - - context = {"task_instance": MagicMock()} - ray_hook.setup_ray_cluster( - context=context, - ray_cluster_yaml="test.yaml", - kuberay_version="1.0.0", - gpu_device_plugin_yaml="gpu.yaml", - update_if_exists=False, - ) - - mock_validate_yaml_file.assert_called_once_with("test.yaml") - mock_install_kuberay_operator.assert_called_once_with(version="1.0.0") - mock_load_yaml_content.assert_called_once_with("test.yaml") - mock_create_custom_object.assert_called_once() - mock_setup_gpu_driver.assert_called_once_with(gpu_device_plugin_yaml="gpu.yaml") - mock_setup_load_balancer.assert_called_once() - - @patch("ray_provider.hooks.RayHook._validate_yaml_file") - @patch("ray_provider.hooks.RayHook.uninstall_kuberay_operator") - @patch("ray_provider.hooks.RayHook.load_yaml_content") - @patch("ray_provider.hooks.RayHook.get_custom_object") - @patch("ray_provider.hooks.RayHook.delete_custom_object") - @patch("ray_provider.hooks.RayHook.get_daemon_set") - @patch("ray_provider.hooks.RayHook.delete_daemon_set") - def test_delete_ray_cluster_success( - self, - mock_delete_daemon_set, - mock_get_daemon_set, - mock_delete_custom_object, - mock_get_custom_object, - mock_load_yaml_content, - mock_uninstall_kuberay_operator, - mock_validate_yaml_file, - ray_hook, - ): - mock_load_yaml_content.return_value = { - "kind": "RayCluster", - "apiVersion": "ray.io/v1", - "metadata": {"name": "test-cluster"}, - } - mock_get_daemon_set.return_value = MagicMock() - mock_get_custom_object.return_value = MagicMock() - - ray_hook.delete_ray_cluster( - ray_cluster_yaml="test.yaml", - gpu_device_plugin_yaml="gpu.yaml", - ) - - mock_validate_yaml_file.assert_called_once_with("test.yaml") - mock_load_yaml_content.assert_called_with("test.yaml") - mock_get_daemon_set.assert_called_once() - mock_delete_daemon_set.assert_called_once() - mock_delete_custom_object.assert_called_once() - mock_uninstall_kuberay_operator.assert_called_once() - - @patch("ray_provider.hooks.JobSubmissionClient") - def test_ray_client_exception(self, mock_job_client, ray_hook): - mock_job_client.side_effect = Exception("Connection failed") - with pytest.raises(AirflowException) as exc_info: - ray_hook.ray_client() - assert str(exc_info.value) == "Failed to create Ray JobSubmissionClient: Connection failed" - - @patch("ray_provider.hooks.RayHook.get_custom_object") - @patch("ray_provider.hooks.RayHook.create_custom_object") - def test_create_or_update_cluster_exception(self, mock_create, mock_get, ray_hook): - mock_get.side_effect = client.exceptions.ApiException(status=500, reason="Internal Server Error") - with pytest.raises(AirflowException) as exc_info: - ray_hook._create_or_update_cluster( - update_if_exists=False, - group="ray.io", - version="v1", - plural="rayclusters", - name="test-cluster", - namespace="default", - cluster_spec={}, - ) - assert "Error accessing Ray cluster 'test-cluster'" in str(exc_info.value) - - @patch("ray_provider.hooks.RayHook.get_custom_object") - @patch("ray_provider.hooks.RayHook.custom_object_client") - def test_create_or_update_cluster_update(self, mock_client, mock_get, ray_hook): - mock_get.return_value = {"metadata": {"name": "test-cluster"}} - ray_hook._create_or_update_cluster( - update_if_exists=True, - group="ray.io", - version="v1", - plural="rayclusters", - name="test-cluster", - namespace="default", - cluster_spec={"spec": {"some": "config"}}, - ) - mock_client.patch_namespaced_custom_object.assert_called_once_with( - group="ray.io", - version="v1", - namespace="default", - plural="rayclusters", - name="test-cluster", - body={"spec": {"some": "config"}}, - ) - - @patch("ray_provider.hooks.RayHook._validate_yaml_file") - @patch("ray_provider.hooks.RayHook.install_kuberay_operator") - @patch("ray_provider.hooks.RayHook.load_yaml_content") - @patch("ray_provider.hooks.RayHook._create_or_update_cluster") - @patch("ray_provider.hooks.RayHook._setup_gpu_driver") - @patch("ray_provider.hooks.RayHook._setup_load_balancer") - def test_setup_ray_cluster_exception( - self, - mock_setup_lb, - mock_setup_gpu, - mock_create_or_update, - mock_load_yaml, - mock_install_operator, - mock_validate_yaml, - ray_hook, - ): - mock_create_or_update.side_effect = Exception("Cluster creation failed") - context = {"task_instance": MagicMock()} - with pytest.raises(AirflowException) as exc_info: - ray_hook.setup_ray_cluster( - context=context, - ray_cluster_yaml="test.yaml", - kuberay_version="1.0.0", - gpu_device_plugin_yaml="gpu.yaml", - update_if_exists=False, - ) - assert "Failed to set up Ray cluster: Cluster creation failed" in str(exc_info.value) - - @patch("ray_provider.hooks.RayHook._validate_yaml_file") - @patch("ray_provider.hooks.RayHook.load_yaml_content") - @patch("ray_provider.hooks.RayHook.get_custom_object") - @patch("ray_provider.hooks.RayHook.delete_custom_object") - @patch("ray_provider.hooks.RayHook.get_daemon_set") - @patch("ray_provider.hooks.RayHook.delete_daemon_set") - @patch("ray_provider.hooks.RayHook.uninstall_kuberay_operator") - def test_delete_ray_cluster_exception( - self, - mock_uninstall_operator, - mock_delete_daemon_set, - mock_get_daemon_set, - mock_delete_custom_object, - mock_get_custom_object, - mock_load_yaml, - mock_validate_yaml, - ray_hook, - ): - mock_delete_custom_object.side_effect = Exception("Cluster deletion failed") - with pytest.raises(AirflowException) as exc_info: - ray_hook.delete_ray_cluster(ray_cluster_yaml="test.yaml", gpu_device_plugin_yaml="gpu.yaml") - assert "Failed to delete Ray cluster: Cluster deletion failed" in str(exc_info.value) diff --git a/tests/test_ray_operators.py b/tests/test_ray_operators.py deleted file mode 100644 index 72df6b5..0000000 --- a/tests/test_ray_operators.py +++ /dev/null @@ -1,593 +0,0 @@ -from datetime import timedelta -from unittest.mock import MagicMock, Mock, patch - -import pytest -from airflow.exceptions import AirflowException, TaskDeferred -from ray.job_submission import JobStatus - -from ray_provider.operators import DeleteRayCluster, SetupRayCluster, SubmitRayJob -from ray_provider.triggers import RayJobTrigger - - -class TestSetupRayCluster: - @pytest.fixture - def mock_hook(self): - with patch("ray_provider.operators.RayHook") as mock: - yield mock.return_value - - @pytest.fixture - def operator(self): - return SetupRayCluster(task_id="test_setup_ray_cluster", conn_id="test_conn", ray_cluster_yaml="cluster.yaml") - - def test_init(self): - operator = SetupRayCluster( - task_id="test_setup_ray_cluster", - conn_id="test_conn", - ray_cluster_yaml="cluster.yaml", - kuberay_version="1.1.0", - gpu_device_plugin_yaml="custom_gpu_plugin.yaml", - update_if_exists=True, - ) - assert operator.conn_id == "test_conn" - assert operator.ray_cluster_yaml == "cluster.yaml" - assert operator.kuberay_version == "1.1.0" - assert operator.gpu_device_plugin_yaml == "custom_gpu_plugin.yaml" - assert operator.update_if_exists is True - - def test_init_default_values(self): - operator = SetupRayCluster( - task_id="test_setup_ray_cluster", - conn_id="test_conn", - ray_cluster_yaml="cluster.yaml", - ) - assert operator.kuberay_version == "1.0.0" - assert ( - operator.gpu_device_plugin_yaml - == "https://raw.githubusercontent.com/NVIDIA/k8s-device-plugin/v0.9.0/nvidia-device-plugin.yml" - ) - assert operator.update_if_exists is False - - def test_hook_property(self, operator): - with patch("ray_provider.operators.RayHook") as mock_ray_hook: - hook = operator.hook - mock_ray_hook.assert_called_once_with(conn_id=operator.conn_id) - assert hook == mock_ray_hook.return_value - - def test_execute(self, operator, mock_hook): - context = MagicMock() - operator.execute(context) - mock_hook.setup_ray_cluster.assert_called_once_with( - context=context, - ray_cluster_yaml=operator.ray_cluster_yaml, - kuberay_version=operator.kuberay_version, - gpu_device_plugin_yaml=operator.gpu_device_plugin_yaml, - update_if_exists=operator.update_if_exists, - ) - - -class TestDeleteRayCluster: - @pytest.fixture - def mock_hook(self): - with patch("ray_provider.operators.RayHook") as mock: - yield mock.return_value - - @pytest.fixture - def operator(self): - return DeleteRayCluster(task_id="test_delete_ray_cluster", conn_id="test_conn", ray_cluster_yaml="cluster.yaml") - - def test_init(self): - operator = DeleteRayCluster( - task_id="test_delete_ray_cluster", - conn_id="test_conn", - ray_cluster_yaml="cluster.yaml", - gpu_device_plugin_yaml="custom_gpu_plugin.yaml", - ) - assert operator.conn_id == "test_conn" - assert operator.ray_cluster_yaml == "cluster.yaml" - assert operator.gpu_device_plugin_yaml == "custom_gpu_plugin.yaml" - - def test_init_default_gpu_plugin(self): - operator = DeleteRayCluster( - task_id="test_delete_ray_cluster", - conn_id="test_conn", - ray_cluster_yaml="cluster.yaml", - ) - assert ( - operator.gpu_device_plugin_yaml - == "https://raw.githubusercontent.com/NVIDIA/k8s-device-plugin/v0.9.0/nvidia-device-plugin.yml" - ) - - def test_hook_property(self, operator): - with patch("ray_provider.operators.RayHook") as mock_ray_hook: - hook = operator.hook - mock_ray_hook.assert_called_once_with(conn_id=operator.conn_id) - assert hook == mock_ray_hook.return_value - - def test_execute(self, operator, mock_hook): - context = MagicMock() - operator.execute(context) - mock_hook.delete_ray_cluster.assert_called_once_with(operator.ray_cluster_yaml, operator.gpu_device_plugin_yaml) - - -class TestSubmitRayJob: - - @pytest.fixture - def mock_hook(self): - with patch("ray_provider.operators.RayHook") as mock: - yield mock.return_value - - @pytest.fixture - def operator(self): - return SubmitRayJob(task_id="test_task", conn_id="test_conn", entrypoint="python script.py", runtime_env={}) - - @pytest.fixture - def task_instance(self): - return Mock() - - @pytest.fixture - def context(self, task_instance): - return {"ti": task_instance, "task": Mock()} - - def test_init(self): - operator = SubmitRayJob( - task_id="test_task", - conn_id="test_conn", - entrypoint="python script.py", - runtime_env={"pip": ["package1", "package2"]}, - num_cpus=2, - num_gpus=1, - memory=1000, - resources={"custom_resource": 1}, - ray_cluster_yaml="cluster.yaml", - kuberay_version="1.0.0", - update_if_exists=True, - gpu_device_plugin_yaml="https://example.com/plugin.yml", - fetch_logs=True, - wait_for_completion=True, - job_timeout_seconds=1200, - poll_interval=30, - xcom_task_key="task.key", - ) - - assert operator.conn_id == "test_conn" - assert operator.entrypoint == "python script.py" - assert operator.runtime_env == {"pip": ["package1", "package2"]} - assert operator.num_cpus == 2 - assert operator.num_gpus == 1 - assert operator.memory == 1000 - assert operator.ray_resources == {"custom_resource": 1} - assert operator.ray_cluster_yaml == "cluster.yaml" - assert operator.kuberay_version == "1.0.0" - assert operator.update_if_exists == True - assert operator.gpu_device_plugin_yaml == "https://example.com/plugin.yml" - assert operator.fetch_logs == True - assert operator.wait_for_completion == True - assert operator.job_timeout_seconds == timedelta(seconds=1200) - assert operator.poll_interval == 30 - assert operator.xcom_task_key == "task.key" - - def test_init_no_timeout(self): - operator = SubmitRayJob( - task_id="test_task", - conn_id="test_conn", - entrypoint="python script.py", - runtime_env={"pip": ["package1", "package2"]}, - num_cpus=2, - num_gpus=1, - memory=1000, - resources={"custom_resource": 1}, - ray_cluster_yaml="cluster.yaml", - kuberay_version="1.0.0", - update_if_exists=True, - gpu_device_plugin_yaml="https://example.com/plugin.yml", - fetch_logs=True, - wait_for_completion=True, - job_timeout_seconds=0, - poll_interval=30, - xcom_task_key="task.key", - ) - assert operator.job_timeout_seconds is None - - def test_on_kill(self, mock_hook): - operator = SubmitRayJob(task_id="test_task", conn_id="test_conn", entrypoint="python script.py", runtime_env={}) - operator.job_id = "test_job_id" - operator.hook = mock_hook - operator.dashboard_url = "http://dashboard.url" - operator.ray_cluster_yaml = "cluster.yaml" - - with patch.object(operator, "_delete_cluster") as mock_delete_cluster: - operator.on_kill() - - mock_hook.delete_ray_job.assert_called_once_with("http://dashboard.url", "test_job_id") - mock_delete_cluster.assert_called_once() - - def test_get_dashboard_url_with_xcom(self, context, task_instance): - operator = SubmitRayJob( - task_id="test_task", - conn_id="test_conn", - entrypoint="python script.py", - runtime_env={}, - xcom_task_key="task.key", - ) - - task_instance.xcom_pull.return_value = "http://dashboard.url" - result = operator._get_dashboard_url(context) - - assert result == "http://dashboard.url" - task_instance.xcom_pull.assert_called_once_with(task_ids="task", key="key") - - def test_get_dashboard_url_without_xcom(self, context): - operator = SubmitRayJob( - task_id="test_task", - conn_id="test_conn", - entrypoint="python script.py", - runtime_env={}, - ) - - result = operator._get_dashboard_url(context) - - assert result is None - - @patch("ray_provider.operators.RayHook") - def test_setup_cluster(self, mock_ray_hook, context): - operator = SubmitRayJob( - task_id="test_task", - conn_id="test_conn", - entrypoint="python script.py", - runtime_env={}, - ray_cluster_yaml="cluster.yaml", - kuberay_version="1.0.0", - update_if_exists=True, - gpu_device_plugin_yaml="https://example.com/plugin.yml", - ) - - mock_hook = mock_ray_hook.return_value - operator.hook = mock_hook - - operator._setup_cluster(context) - - mock_hook.setup_ray_cluster.assert_called_once_with( - context=context, - ray_cluster_yaml="cluster.yaml", - kuberay_version="1.0.0", - gpu_device_plugin_yaml="https://example.com/plugin.yml", - update_if_exists=True, - ) - - @patch("ray_provider.operators.RayHook") - def test_delete_cluster(self, mock_ray_hook): - operator = SubmitRayJob( - task_id="test_task", - conn_id="test_conn", - entrypoint="python script.py", - runtime_env={}, - ray_cluster_yaml="cluster.yaml", - gpu_device_plugin_yaml="https://example.com/plugin.yml", - ) - - mock_hook = mock_ray_hook.return_value - operator.hook = mock_hook - - operator._delete_cluster() - - mock_hook.delete_ray_cluster.assert_called_once_with( - ray_cluster_yaml="cluster.yaml", - gpu_device_plugin_yaml="https://example.com/plugin.yml", - ) - - def test_execute_without_wait(self, mock_hook, context): - operator = SubmitRayJob( - task_id="test_task", - conn_id="test_conn", - entrypoint="python script.py", - runtime_env={}, - wait_for_completion=False, - ) - - mock_hook.submit_ray_job.return_value = "test_job_id" - - with patch.object(operator, "_setup_cluster") as mock_setup_cluster: - result = operator.execute(context) - - mock_setup_cluster.assert_called_once_with(context=context) - assert result == "test_job_id" - mock_hook.submit_ray_job.assert_called_once_with( - dashboard_url=None, - entrypoint="python script.py", - runtime_env={}, - entrypoint_num_cpus=0, - entrypoint_num_gpus=0, - entrypoint_memory=0, - entrypoint_resources=None, - ) - - @pytest.mark.parametrize( - "job_status,expected_action", - [ - (JobStatus.PENDING, "defer"), - (JobStatus.RUNNING, "defer"), - (JobStatus.SUCCEEDED, None), - (JobStatus.FAILED, "raise"), - (JobStatus.STOPPED, "raise"), - ], - ) - def test_execute_with_wait(self, mock_hook, context, job_status, expected_action): - operator = SubmitRayJob( - task_id="test_task", - conn_id="test_conn", - entrypoint="python script.py", - runtime_env={}, - wait_for_completion=True, - ) - - mock_hook.submit_ray_job.return_value = "test_job_id" - mock_hook.get_ray_job_status.return_value = job_status - - with patch.object(operator, "_setup_cluster"): - if expected_action == "defer": - with patch.object(operator, "defer") as mock_defer: - operator.execute(context) - mock_defer.assert_called_once() - elif expected_action == "raise": - with pytest.raises(AirflowException): - operator.execute(context) - else: - result = operator.execute(context) - assert result == "test_job_id" - - @pytest.mark.parametrize( - "event_status,expected_action", - [ - (JobStatus.SUCCEEDED, None), - (JobStatus.FAILED, "raise"), - (JobStatus.STOPPED, "raise"), - ("UNEXPECTED", "raise"), - ], - ) - def test_execute_complete(self, operator, event_status, expected_action): - operator.job_id = "test_job_id" - event = {"status": event_status, "message": "Test message"} - - with patch.object(operator, "_delete_cluster") as mock_delete_cluster: - if expected_action == "raise": - with pytest.raises(AirflowException): - operator.execute_complete({}, event) - else: - operator.execute_complete({}, event) - - # _delete_cluster should be called in all cases - mock_delete_cluster.assert_called_once() - - def test_template_fields(self): - assert SubmitRayJob.template_fields == ( - "conn_id", - "entrypoint", - "runtime_env", - "num_cpus", - "num_gpus", - "memory", - "xcom_task_key", - "ray_cluster_yaml", - "job_timeout_seconds", - ) - - @patch("ray_provider.operators.RayHook") - def test_setup_cluster_exception(self, mock_ray_hook, context): - operator = SubmitRayJob( - task_id="test_task", - conn_id="test_conn", - entrypoint="python script.py", - runtime_env={}, - ray_cluster_yaml="cluster.yaml", - ) - - mock_hook = mock_ray_hook.return_value - operator.hook = mock_hook - - mock_hook.setup_ray_cluster.side_effect = Exception("Cluster setup failed") - - with pytest.raises(Exception) as exc_info: - operator._setup_cluster(context) - - assert str(exc_info.value) == "Cluster setup failed" - mock_hook.setup_ray_cluster.assert_called_once() - - @patch("ray_provider.operators.RayHook") - def test_delete_cluster_exception(self, mock_ray_hook): - operator = SubmitRayJob( - task_id="test_task", - conn_id="test_conn", - entrypoint="python script.py", - runtime_env={}, - ray_cluster_yaml="cluster.yaml", - ) - - mock_hook = mock_ray_hook.return_value - operator.hook = mock_hook - - mock_hook.delete_ray_cluster.side_effect = Exception("Cluster deletion failed") - - with pytest.raises(Exception) as exc_info: - operator._delete_cluster() - - assert str(exc_info.value) == "Cluster deletion failed" - mock_hook.delete_ray_cluster.assert_called_once() - - @pytest.mark.parametrize( - "xcom_task_key, expected_task, expected_key", - [ - ("task.key", "task", "key"), - ("single_key", None, "single_key"), - ], - ) - def test_get_dashboard_url_xcom_variants(self, operator, context, xcom_task_key, expected_task, expected_key): - operator.xcom_task_key = xcom_task_key - context["ti"].xcom_pull.return_value = "http://dashboard.url" - - result = operator._get_dashboard_url(context) - - assert result == "http://dashboard.url" - if expected_task: - context["ti"].xcom_pull.assert_called_once_with(task_ids=expected_task, key=expected_key) - else: - context["ti"].xcom_pull.assert_called_once_with(task_ids=context["task"].task_id, key=expected_key) - - def test_execute_job_unexpected_state(self, mock_hook, context): - operator = SubmitRayJob( - task_id="test_task", - conn_id="test_conn", - entrypoint="python script.py", - runtime_env={}, - wait_for_completion=True, - ) - mock_hook.submit_ray_job.return_value = "test_job_id" - mock_hook.get_ray_job_status.return_value = "UNEXPECTED_STATE" - - with patch.object(operator, "_setup_cluster"), pytest.raises(TaskDeferred) as exc_info: - operator.execute(context) - - assert isinstance(exc_info.value.trigger, RayJobTrigger) - - @pytest.mark.parametrize("dashboard_url", [None, "http://dashboard.url"]) - def test_execute_defer(self, mock_hook, context, dashboard_url): - operator = SubmitRayJob( - task_id="test_task", - conn_id="test_conn", - entrypoint="python script.py", - runtime_env={}, - wait_for_completion=True, - ray_cluster_yaml="cluster.yaml", - gpu_device_plugin_yaml="gpu_plugin.yaml", - poll_interval=30, - fetch_logs=True, - job_timeout_seconds=600, - ) - mock_hook.submit_ray_job.return_value = "test_job_id" - mock_hook.get_ray_job_status.return_value = JobStatus.PENDING - - with patch.object(operator, "_setup_cluster"), patch.object( - operator, "_get_dashboard_url", return_value=dashboard_url - ), pytest.raises(TaskDeferred) as exc_info: - operator.execute(context) - - trigger = exc_info.value.trigger - assert isinstance(trigger, RayJobTrigger) - assert trigger.job_id == "test_job_id" - assert trigger.conn_id == "test_conn" - assert trigger.dashboard_url == dashboard_url - assert trigger.ray_cluster_yaml == "cluster.yaml" - assert trigger.gpu_device_plugin_yaml == "gpu_plugin.yaml" - assert trigger.poll_interval == 30 - assert trigger.fetch_logs is True - - def test_execute_complete_unexpected_status(self, operator): - event = {"status": "UNEXPECTED", "message": "Unexpected status"} - with patch.object(operator, "_delete_cluster"), pytest.raises(AirflowException) as exc_info: - operator.execute_complete({}, event) - - assert "Unexpected event status" in str(exc_info.value) - - def test_execute_complete_cleanup_on_exception(self, operator): - event = {"status": JobStatus.FAILED, "message": "Job failed"} - with patch.object(operator, "_delete_cluster") as mock_delete_cluster, pytest.raises(AirflowException): - operator.execute_complete({}, event) - - mock_delete_cluster.assert_called_once() - - def test_execute_exception_handling(self, mock_hook, context): - operator = SubmitRayJob( - task_id="test_task", - conn_id="test_conn", - entrypoint="python script.py", - runtime_env={}, - ray_cluster_yaml="cluster.yaml", - ) - - mock_hook.submit_ray_job.side_effect = Exception("Job submission failed") - - with patch.object(operator, "_setup_cluster"), patch.object( - operator, "_delete_cluster" - ) as mock_delete_cluster, pytest.raises(AirflowException) as exc_info: - operator.execute(context) - - assert "SubmitRayJob operator failed due to Job submission failed" in str(exc_info.value) - mock_delete_cluster.assert_called_once() - - def test_execute_cluster_setup_exception(self, mock_hook, context): - operator = SubmitRayJob( - task_id="test_task", - conn_id="test_conn", - entrypoint="python script.py", - runtime_env={}, - ray_cluster_yaml="cluster.yaml", - ) - - with patch.object(operator, "_setup_cluster", side_effect=Exception("Cluster setup failed")), patch.object( - operator, "_delete_cluster" - ) as mock_delete_cluster, pytest.raises(AirflowException) as exc_info: - operator.execute(context) - - assert "SubmitRayJob operator failed due to Cluster setup failed" in str(exc_info.value) - mock_delete_cluster.assert_called_once() - - def test_execute_with_wait_and_defer(self, mock_hook, context): - operator = SubmitRayJob( - task_id="test_task", - conn_id="test_conn", - entrypoint="python script.py", - runtime_env={}, - wait_for_completion=True, - poll_interval=30, - fetch_logs=True, - job_timeout_seconds=600, - ) - - mock_hook.submit_ray_job.return_value = "test_job_id" - mock_hook.get_ray_job_status.return_value = JobStatus.PENDING - - with patch.object(operator, "_setup_cluster"), patch.object(operator, "defer") as mock_defer: - operator.execute(context) - - mock_defer.assert_called_once() - args, kwargs = mock_defer.call_args - assert isinstance(kwargs["trigger"], RayJobTrigger) - assert kwargs["method_name"] == "execute_complete" - assert kwargs["timeout"].total_seconds() == 600 - - def test_execute_complete_with_cleanup(self, operator): - operator.job_id = "test_job_id" - event = {"status": JobStatus.FAILED, "message": "Job failed"} - - with patch.object(operator, "_delete_cluster") as mock_delete_cluster, pytest.raises(AirflowException): - operator.execute_complete({}, event) - - mock_delete_cluster.assert_called_once() - - def test_execute_without_wait_no_cleanup(self, mock_hook, context): - operator = SubmitRayJob( - task_id="test_task", - conn_id="test_conn", - entrypoint="python script.py", - runtime_env={}, - wait_for_completion=False, - ) - - mock_hook.submit_ray_job.return_value = "test_job_id" - - with patch.object(operator, "_setup_cluster") as mock_setup_cluster, patch.object( - operator, "_delete_cluster" - ) as mock_delete_cluster: - result = operator.execute(context) - - mock_setup_cluster.assert_called_once_with(context=context) - assert result == "test_job_id" - mock_hook.submit_ray_job.assert_called_once_with( - dashboard_url=None, - entrypoint="python script.py", - runtime_env={}, - entrypoint_num_cpus=0, - entrypoint_num_gpus=0, - entrypoint_memory=0, - entrypoint_resources=None, - ) - mock_delete_cluster.assert_not_called() diff --git a/tests/test_ray_triggers.py b/tests/test_ray_triggers.py deleted file mode 100644 index f97611e..0000000 --- a/tests/test_ray_triggers.py +++ /dev/null @@ -1,241 +0,0 @@ -import logging -from unittest.mock import AsyncMock, call, patch - -import pytest -from airflow.triggers.base import TriggerEvent -from ray.job_submission import JobStatus - -from ray_provider.triggers import RayJobTrigger - - -class TestRayJobTrigger: - @pytest.fixture - def trigger(self): - return RayJobTrigger( - job_id="test_job_id", - conn_id="test_conn", - xcom_dashboard_url="http://test-dashboard.com", - ray_cluster_yaml="test.yaml", - gpu_device_plugin_yaml="nvidia.yaml", - poll_interval=1, - fetch_logs=True, - ) - - @pytest.mark.asyncio - @patch("ray_provider.triggers.RayJobTrigger._is_terminal_state") - @patch("ray_provider.triggers.RayJobTrigger.hook") - async def test_run_no_job_id(self, mock_hook, mock_is_terminal): - mock_is_terminal.return_value = True - mock_hook.get_ray_job_status.return_value = JobStatus.FAILED - trigger = RayJobTrigger( - job_id="", - poll_interval=1, - conn_id="test", - xcom_dashboard_url="test", - ray_cluster_yaml="test.yaml", - gpu_device_plugin_yaml="nvidia.yaml", - ) - generator = trigger.run() - event = await generator.asend(None) - assert event == TriggerEvent( - {"status": JobStatus.FAILED, "message": "Job completed with status FAILED", "job_id": ""} - ) - - @pytest.mark.asyncio - @patch("ray_provider.triggers.RayJobTrigger._is_terminal_state") - @patch("ray_provider.triggers.RayJobTrigger.hook") - async def test_run_job_succeeded(self, mock_hook, mock_is_terminal): - mock_is_terminal.side_effect = [False, True] - mock_hook.get_ray_job_status.return_value = JobStatus.SUCCEEDED - trigger = RayJobTrigger( - job_id="test_job_id", - poll_interval=1, - conn_id="test", - xcom_dashboard_url="test", - ray_cluster_yaml="test.yaml", - gpu_device_plugin_yaml="nvidia.yaml", - ) - generator = trigger.run() - event = await generator.asend(None) - assert event == TriggerEvent( - { - "status": JobStatus.SUCCEEDED, - "message": f"Job test_job_id completed with status {JobStatus.SUCCEEDED}", - "job_id": "test_job_id", - } - ) - - @pytest.mark.asyncio - @patch("ray_provider.triggers.RayJobTrigger._is_terminal_state") - @patch("ray_provider.triggers.RayJobTrigger.hook") - async def test_run_job_stopped(self, mock_hook, mock_is_terminal, trigger): - mock_is_terminal.side_effect = [False, True] - mock_hook.get_ray_job_status.return_value = JobStatus.STOPPED - - generator = trigger.run() - event = await generator.asend(None) - - assert event == TriggerEvent( - { - "status": JobStatus.STOPPED, - "message": f"Job test_job_id completed with status {JobStatus.STOPPED}", - "job_id": "test_job_id", - } - ) - - @pytest.mark.asyncio - @patch("ray_provider.triggers.RayJobTrigger._is_terminal_state") - @patch("ray_provider.triggers.RayJobTrigger.hook") - async def test_run_job_failed(self, mock_hook, mock_is_terminal, trigger): - mock_is_terminal.side_effect = [False, True] - mock_hook.get_ray_job_status.return_value = JobStatus.FAILED - - generator = trigger.run() - event = await generator.asend(None) - - assert event == TriggerEvent( - { - "status": JobStatus.FAILED, - "message": f"Job test_job_id completed with status {JobStatus.FAILED}", - "job_id": "test_job_id", - } - ) - - @pytest.mark.asyncio - @patch("ray_provider.triggers.RayJobTrigger._is_terminal_state") - @patch("ray_provider.triggers.RayJobTrigger.hook") - @patch("ray_provider.triggers.RayJobTrigger._stream_logs") - async def test_run_with_log_streaming(self, mock_stream_logs, mock_hook, mock_is_terminal, trigger): - mock_is_terminal.side_effect = [False, True] - mock_hook.get_ray_job_status.return_value = JobStatus.SUCCEEDED - mock_stream_logs.return_value = None - - generator = trigger.run() - event = await generator.asend(None) - - mock_stream_logs.assert_called_once() - assert event == TriggerEvent( - { - "status": JobStatus.SUCCEEDED, - "message": f"Job test_job_id completed with status {JobStatus.SUCCEEDED}", - "job_id": "test_job_id", - } - ) - - @pytest.mark.asyncio - @patch("ray_provider.triggers.RayJobTrigger.hook") - async def test_stream_logs(self, mock_hook, trigger): - # Create a mock async iterator - async def mock_async_iterator(): - for item in ["Log line 1\n", "Log line 2\n"]: - yield item - - # Set up the mock to return an async iterator - mock_hook.get_ray_tail_logs.return_value = mock_async_iterator() - - with patch("ray_provider.triggers.RayJobTrigger.log") as mock_log: - await trigger._stream_logs() - - mock_log.info.assert_any_call("::group::test_job_id logs") - mock_log.info.assert_any_call("Log line 1") - mock_log.info.assert_any_call("Log line 2") - mock_log.info.assert_any_call("::endgroup::") - - def test_serialize(self, trigger): - serialized = trigger.serialize() - assert serialized == ( - "ray_provider.triggers.RayJobTrigger", - { - "job_id": "test_job_id", - "conn_id": "test_conn", - "xcom_dashboard_url": "http://test-dashboard.com", - "ray_cluster_yaml": "test.yaml", - "gpu_device_plugin_yaml": "nvidia.yaml", - "fetch_logs": True, - "poll_interval": 1, - }, - ) - - @pytest.mark.asyncio - @patch("ray_provider.triggers.RayJobTrigger.hook") - async def test_is_terminal_state(self, mock_hook, trigger): - mock_hook.get_ray_job_status.side_effect = [ - JobStatus.PENDING, - JobStatus.RUNNING, - JobStatus.SUCCEEDED, - ] - - assert not trigger._is_terminal_state() - assert not trigger._is_terminal_state() - assert trigger._is_terminal_state() - - @pytest.mark.asyncio - @patch.object(RayJobTrigger, "hook") - @patch.object(logging.Logger, "info") - async def test_cleanup_with_cluster_yaml(self, mock_log_info, mock_hook, trigger): - await trigger.cleanup() - - mock_log_info.assert_has_calls( - [ - call("Attempting to delete Ray cluster using YAML: test.yaml"), - call("Ray cluster deletion process completed"), - ] - ) - mock_hook.delete_ray_cluster.assert_called_once_with("test.yaml", "nvidia.yaml") - - @pytest.mark.asyncio - @patch.object(logging.Logger, "info") - async def test_cleanup_without_cluster_yaml(self, mock_log_info): - trigger = RayJobTrigger( - job_id="test_job_id", - conn_id="test_conn", - xcom_dashboard_url="http://test-dashboard.com", - ray_cluster_yaml=None, - gpu_device_plugin_yaml="nvidia.yaml", - poll_interval=1, - fetch_logs=True, - ) - - await trigger.cleanup() - - mock_log_info.assert_called_once_with("No Ray cluster YAML provided, skipping cluster deletion") - - @pytest.mark.asyncio - @patch.object(RayJobTrigger, "hook") - @patch.object(logging.Logger, "error") - async def test_cleanup_with_exception(self, mock_log_error, mock_hook, trigger): - mock_hook.delete_ray_cluster.side_effect = Exception("Test exception") - - await trigger.cleanup() - - mock_log_error.assert_called_once_with("Unexpected error during cleanup: Test exception") - - @pytest.mark.asyncio - @patch("asyncio.sleep", new_callable=AsyncMock) - @patch("ray_provider.triggers.RayJobTrigger._is_terminal_state") - async def test_poll_status(self, mock_is_terminal, mock_sleep, trigger): - mock_is_terminal.side_effect = [False, False, True] - - await trigger._poll_status() - - assert mock_sleep.call_count == 2 - mock_sleep.assert_called_with(1) - - @pytest.mark.asyncio - @patch("ray_provider.triggers.RayJobTrigger._is_terminal_state") - @patch("ray_provider.triggers.RayJobTrigger.hook") - @patch("ray_provider.triggers.RayJobTrigger.cleanup") - async def test_run_with_exception(self, mock_cleanup, mock_hook, mock_is_terminal, trigger): - mock_is_terminal.side_effect = Exception("Test exception") - - generator = trigger.run() - event = await generator.asend(None) - - assert event == TriggerEvent( - { - "status": str(JobStatus.FAILED), - "message": "Test exception", - "job_id": "test_job_id", - } - ) - mock_cleanup.assert_called_once() From 13ccc3bf0e3d0f2e2a5f5644170aac9d5d67f4da Mon Sep 17 00:00:00 2001 From: Tatiana Al-Chueyr Date: Fri, 29 Nov 2024 08:22:18 +0000 Subject: [PATCH 5/6] Re order things so it is easier for code review --- ray_provider/hooks.py | 437 +++++++++++++++++++----------------------- 1 file changed, 195 insertions(+), 242 deletions(-) diff --git a/ray_provider/hooks.py b/ray_provider/hooks.py index 196358d..560a325 100644 --- a/ray_provider/hooks.py +++ b/ray_provider/hooks.py @@ -5,6 +5,7 @@ import subprocess import tempfile import time +from functools import cached_property from typing import Any, AsyncIterator import requests @@ -17,6 +18,8 @@ from ray_provider.constants import TERMINAL_JOB_STATUSES +DEFAULT_NAMESPACE = "default" + class RayHook(KubernetesHook): # type: ignore """ @@ -33,7 +36,43 @@ class RayHook(KubernetesHook): # type: ignore conn_type = "ray" hook_name = "Ray" - DEFAULT_NAMESPACE = "default" + @classmethod + def get_ui_field_behaviour(cls) -> dict[str, Any]: + """ + Return custom field behaviour for the connection form. + + :return: A dictionary specifying custom field behaviour. + """ + return { + "hidden_fields": ["host", "schema", "login", "password", "port", "extra"], + "relabeling": {}, + } + + @classmethod + def get_connection_form_widgets(cls) -> dict[str, Any]: + """ + Return connection widgets to add to connection form. + + :return: A dictionary of connection form widgets. + """ + from flask_appbuilder.fieldwidgets import BS3PasswordFieldWidget, BS3TextFieldWidget + from flask_babel import lazy_gettext + from wtforms import BooleanField, PasswordField, StringField + + return { + "address": StringField(lazy_gettext("Ray dashboard url"), widget=BS3TextFieldWidget()), + # "create_cluster_if_needed": BooleanField(lazy_gettext("Create cluster if needed")), + "cookies": StringField(lazy_gettext("Cookies"), widget=BS3TextFieldWidget()), + "metadata": StringField(lazy_gettext("Metadata"), widget=BS3TextFieldWidget()), + "headers": StringField(lazy_gettext("Headers"), widget=BS3TextFieldWidget()), + "verify": BooleanField(lazy_gettext("Verify")), + "kube_config_path": StringField(lazy_gettext("Kube config path"), widget=BS3TextFieldWidget()), + "kube_config": PasswordField(lazy_gettext("Kube config (JSON format)"), widget=BS3PasswordFieldWidget()), + "namespace": StringField(lazy_gettext("Namespace"), widget=BS3TextFieldWidget()), + "cluster_context": StringField(lazy_gettext("Cluster context"), widget=BS3TextFieldWidget()), + "disable_verify_ssl": BooleanField(lazy_gettext("Disable SSL")), + "disable_tcp_keepalive": BooleanField(lazy_gettext("Disable TCP keepalive")), + } def __init__( self, @@ -56,7 +95,7 @@ def __init__( self.verify = self._get_field("verify") or False self.ray_client_instance = None - self.default_namespace = self.get_namespace() or self.DEFAULT_NAMESPACE + self.default_namespace = self.get_namespace() or DEFAULT_NAMESPACE self.kubeconfig: str | None = None self.in_cluster: bool | None = None self.client_configuration = None @@ -68,23 +107,23 @@ def __init__( self.cluster_context = self._get_field("cluster_context") self.kubeconfig_path = self._get_field("kube_config_path") self.kubeconfig_content = self._get_field("kube_config") - self.ray_cluster_yaml = None + self.ray_cluster_yaml: None | str = None self._setup_kubeconfig(self.kubeconfig_path, self.kubeconfig_content, self.cluster_context) - @property # TODO: cached property + # Create a PR for this + @cached_property def namespace(self): if self.ray_cluster_yaml is None: return self.default_namespace cluster_spec = self.load_yaml_content(self.ray_cluster_yaml) return cluster_spec["metadata"].get("namespace") or self.default_namespace + # Create another PR for this def test_connection(self): job_client = self.ray_client(self.address) - job_id = job_client.submit_job( - entrypoint="import ray; ray.init(); print(ray.cluster_resources())" - ) + job_id = job_client.submit_job(entrypoint="import ray; ray.init(); print(ray.cluster_resources())") self.log.info(f"Ray test connection: Submitted job with ID: {job_id}") job_completed = False @@ -97,52 +136,12 @@ def test_connection(self): job_completed = True connection_attempt -= 1 - if job_status != JobStatus.SUCCEEDED: return False, f"Ray test connection failed: Job {job_id} status {job_status}" return True, job_status # TODO: check webserver logs - @classmethod - def get_ui_field_behaviour(cls) -> dict[str, Any]: - """ - Return custom field behaviour for the connection form. - - :return: A dictionary specifying custom field behaviour. - """ - return { - "hidden_fields": ["host", "schema", "login", "password", "port", "extra"], - "relabeling": {}, - } - - @classmethod - def get_connection_form_widgets(cls) -> dict[str, Any]: - """ - Return connection widgets to add to connection form. - - :return: A dictionary of connection form widgets. - """ - from flask_appbuilder.fieldwidgets import BS3PasswordFieldWidget, BS3TextFieldWidget - from flask_babel import lazy_gettext - from wtforms import BooleanField, PasswordField, StringField - - return { - "address": StringField(lazy_gettext("Ray dashboard url"), widget=BS3TextFieldWidget()), - # "create_cluster_if_needed": BooleanField(lazy_gettext("Create cluster if needed")), - "cookies": StringField(lazy_gettext("Cookies"), widget=BS3TextFieldWidget()), - "metadata": StringField(lazy_gettext("Metadata"), widget=BS3TextFieldWidget()), - "headers": StringField(lazy_gettext("Headers"), widget=BS3TextFieldWidget()), - "verify": BooleanField(lazy_gettext("Verify")), - "kube_config_path": StringField(lazy_gettext("Kube config path"), widget=BS3TextFieldWidget()), - "kube_config": PasswordField(lazy_gettext("Kube config (JSON format)"), widget=BS3PasswordFieldWidget()), - "namespace": StringField(lazy_gettext("Namespace"), widget=BS3TextFieldWidget()), - "cluster_context": StringField(lazy_gettext("Cluster context"), widget=BS3TextFieldWidget()), - "disable_verify_ssl": BooleanField(lazy_gettext("Disable SSL")), - "disable_tcp_keepalive": BooleanField(lazy_gettext("Disable TCP keepalive")), - } - - def _setup_kubeconfig( self, kubeconfig_path: str | None, kubeconfig_content: str | None, cluster_context: str | None ) -> None: @@ -187,19 +186,16 @@ def ray_client(self, dashboard_url: str | None = None) -> JobSubmissionClient: :raises AirflowException: If the connection fails. """ if not self.ray_client_instance: - try: - self.log.info(f"Address URL is: {self.address}") - self.log.info(f"Dashboard URL is: {dashboard_url}") - self.ray_client_instance = JobSubmissionClient( - address=dashboard_url or self.address, - create_cluster_if_needed=self.create_cluster_if_needed, - cookies=self.cookies, - metadata=self.metadata, - headers=self.headers, - verify=self.verify, - ) - except Exception as e: - raise AirflowException(f"Failed to create Ray JobSubmissionClient: {e}") + self.log.info(f"Address URL is: {self.address}") + self.log.info(f"Dashboard URL is: {dashboard_url}") + self.ray_client_instance = JobSubmissionClient( + address=dashboard_url or self.address, + create_cluster_if_needed=self.create_cluster_if_needed, + cookies=self.cookies, + metadata=self.metadata, + headers=self.headers, + verify=self.verify, + ) return self.ray_client_instance def submit_ray_job( @@ -341,7 +337,6 @@ def _check_load_balancer_readiness(self, lb_details: dict[str, Any]) -> str | No ip: str | None = lb_details["ip"] hostname: str | None = lb_details["hostname"] - self.log.info(f"ports: {lb_details['ports']}") for port_info in lb_details["ports"]: port = port_info["port"] if ip and self._is_port_open(ip, port): @@ -351,123 +346,6 @@ def _check_load_balancer_readiness(self, lb_details: dict[str, Any]) -> str | No return None - def _get_node_ip(self) -> str: - """ - Retrieve the IP address of a Kubernetes node. - - :return: The IP address of a node in the Kubernetes cluster. - """ - # Example: Retrieve the first node's IP (adjust based on your cluster setup) - nodes = self.core_v1_client.list_node().items - self.log.info(f"Nodes: {nodes}") - for node in nodes: - self.log.info(f"Node address: {node.status.addresses}") - for address in node.status.addresses: - if address.type == "ExternalIP": - return address.address - - for node in nodes: - self.log.info(f"Node address: {node.status.addresses}") - for address in node.status.addresses: - if address.type == "InternalIP": - return address.address - - raise AirflowException("No valid node IP found in the cluster.") - - def _setup_node_port(self, name: str, namespace: str, context: dict) -> None: - """ - Set up the NodePort service and push URLs to XCom. - - :param name: The name of the Ray cluster. - :param namespace: The namespace where the cluster is deployed. - :param context: The Airflow task context. - """ - node_port_details: dict[str, Any] = self._wait_for_node_port_service( - service_name=f"{name}-head-svc", namespace=namespace - ) - - if node_port_details: - self.log.info(node_port_details) - - node_ports = node_port_details["node_ports"] - # Example: Assuming `node_ip` is provided as an environment variable or a known cluster node. - node_ip = self._get_node_ip() # Implement this method to return a valid node IP or DNS. - - for port in node_ports: - url = f"http://{node_ip}:{port['port']}" - context["task_instance"].xcom_push(key=port["name"], value=url) - self.log.info(f"Pushed URL to XCom: {url}") - else: - self.log.info("No NodePort URLs to push to XCom.") - - def _wait_for_node_port_service( - self, - service_name: str, - namespace: str = "default", - max_retries: int = 30, - retry_interval: int = 10, - ) -> dict[str, Any]: - """ - Wait for the NodePort service to be ready and return its details. - - :param service_name: The name of the NodePort service. - :param namespace: The namespace of the service. - :param max_retries: Maximum number of retries. - :param retry_interval: Interval between retries in seconds. - :return: A dictionary containing NodePort service details. - :raises AirflowException: If the service does not become ready within the specified retries. - """ - for attempt in range(1, max_retries + 1): - self.log.info(f"Attempt {attempt}: Checking NodePort service status...") - - try: - service: client.V1Service = self._get_service(service_name, namespace) - service_details: dict[str, Any] | None = self._get_node_port_details(service) - - if service_details: - self.log.info("NodePort service is ready.") - return service_details - - self.log.info("NodePort details not available yet. Retrying...") - except AirflowException: - self.log.info("Service is not available yet.") - - time.sleep(retry_interval) - - raise AirflowException(f"Service did not become ready after {max_retries} attempts") - - def _get_node_port_details(self, service: client.V1Service) -> dict[str, Any] | None: - """ - Extract NodePort details from the service. - - :param service: The Kubernetes service object. - :return: A dictionary containing NodePort details if available, None otherwise. - """ - node_ports = [] - for port in service.spec.ports: - if port.node_port: - node_ports.append({"name": port.name, "port": port.node_port}) - - if node_ports: - return {"node_ports": node_ports} - - return None - - def _check_node_port_connectivity(self, node_ports: list[dict[str, Any]]) -> bool: - """ - Check if the NodePort is reachable. - - :param node_ports: List of NodePort details. - :return: True if at least one NodePort is accessible, False otherwise. - """ - for port_info in node_ports: - # Replace with actual logic to test connectivity if needed. - self.log.info(f"Checking connectivity for NodePort {port_info['port']}") - # Example: Simulate readiness check. - if self._is_port_open("example-node-ip", port_info["port"]): - return True - return False - def _wait_for_load_balancer( self, service_name: str, @@ -515,41 +393,6 @@ def _wait_for_load_balancer( raise AirflowException(f"LoadBalancer did not become ready after {max_retries} attempts") - def _get_load_balancer_details(self, service: client.V1Service) -> dict[str, Any] | None: - """ - Extract LoadBalancer details from the service. - - :param service: The Kubernetes service object. - :return: A dictionary containing LoadBalancer details if available, None otherwise. - """ - if service.status.load_balancer.ingress: - ingress: client.V1LoadBalancerIngress = service.status.load_balancer.ingress[0] - ip: str | None = ingress.ip - hostname: str | None = ingress.hostname - if ip or hostname: - ports: list[dict[str, Any]] = [{"name": port.name, "port": port.port} for port in service.spec.ports] - return {"ip": ip, "hostname": hostname, "ports": ports} - return None - - def _check_load_balancer_readiness(self, lb_details: dict[str, Any]) -> str | None: - """ - Check if the LoadBalancer is ready by testing port connectivity. - - :param lb_details: Dictionary containing LoadBalancer details. - :return: The working address (IP or hostname) if ready, None otherwise. - """ - ip: str | None = lb_details["ip"] - hostname: str | None = lb_details["hostname"] - - for port_info in lb_details["ports"]: - port = port_info["port"] - if ip and self._is_port_open(ip, port): - return ip - if hostname and self._is_port_open(hostname, port): - return hostname - - return None - def _validate_yaml_file(self, yaml_file: str) -> None: """ Validate the existence and format of the YAML file. @@ -590,16 +433,13 @@ def _create_or_update_cluster( :param cluster_spec: The specification of the Ray cluster. :raises AirflowException: If there's an error accessing or creating the Ray cluster. """ - """self.get_custom_object(group=group, version=version, plural=plural, name=name, namespace=namespace) if update_if_exists: + self.log.info(f"Updating existing Ray cluster: {name}") + self.get_custom_object(group=group, version=version, plural=plural, name=name, namespace=namespace) self.custom_object_client.patch_namespaced_custom_object( group=group, version=version, namespace=namespace, plural=plural, name=name, body=cluster_spec ) - - except client.exceptions.ApiException as e: - if e.status == 404: - """ self.log.info(f"Creating new Ray cluster: {name}") @@ -608,13 +448,16 @@ def _create_or_update_cluster( ) self.log.info(f"Resource created. Response: {response}") + # TODO: may go to a different PR start_time = time.time() wait_timeout = 300 poll_interval = 5 while time.time() - start_time < wait_timeout: try: - resource = self.get_custom_object(group=group, version=version, plural=plural, name=name, namespace=namespace) + resource = self.get_custom_object( + group=group, version=version, plural=plural, name=name, namespace=namespace + ) except client.exceptions.ApiException as e: self.log.warning(f"Error fetching resource status: {e}") else: @@ -626,12 +469,9 @@ def _create_or_update_cluster( time.sleep(poll_interval) - raise TimeoutError(f"Resource {name} of group {group} did not reach the desired state within {wait_timeout} seconds.") - - """ - else: - raise AirflowException(f"Error accessing Ray cluster '{name}': {e}") - """ + raise TimeoutError( + f"Resource {name} of group {group} did not reach the desired state within {wait_timeout} seconds." + ) def _setup_gpu_driver(self, gpu_device_plugin_yaml: str) -> None: """ @@ -685,7 +525,6 @@ def setup_ray_cluster( :param update_if_exists: Whether to update the cluster if it already exists. :raises AirflowException: If there's an error setting up the Ray cluster. """ - #try: self._validate_yaml_file(ray_cluster_yaml) self.ray_cluster_yaml = ray_cluster_yaml @@ -716,24 +555,20 @@ def setup_ray_cluster( ) except TimeoutError as e: self._delete_ray_cluster_crd(ray_cluster_yaml) - raise AirflowException(e) + raise e self.log.info("::endgroup::") - #self._setup_gpu_driver(gpu_device_plugin_yaml=gpu_device_plugin_yaml) + self._setup_gpu_driver(gpu_device_plugin_yaml=gpu_device_plugin_yaml) - #self.log.info("::group:: (Step 3/3) Setup Node Port service") - #self._setup_node_port(name, namespace, context) - #self.log.info("::endgroup::") + # TODO: separate PR + # self.log.info("::group:: (Step 3/3) Setup Node Port service") + # self._setup_node_port(name, namespace, context) + # self.log.info("::endgroup::") self.log.info("::group:: (Setup 3/3) Setup Load Balancer service") self._setup_load_balancer(name, namespace, context) self.log.info("::endgroup::") - #except Exception as e: - # self.log.error(f"Error setting up Ray cluster: {e}") - # raise AirflowException(f"Failed to set up Ray cluster: {e}") - - def _delete_ray_cluster_crd(self, ray_cluster_yaml: str) -> None: """ Delete the Ray cluster based on the cluster specification. @@ -763,7 +598,6 @@ def _delete_ray_cluster_crd(self, ray_cluster_yaml: str) -> None: self.delete_custom_object(group=group, version=version, name=name, namespace=namespace, plural=plural) self.log.info(f"Deleted Ray cluster: {name}") - def delete_ray_cluster(self, ray_cluster_yaml: str, gpu_device_plugin_yaml: str) -> None: """ Execute the operator to delete the Ray cluster. @@ -772,11 +606,10 @@ def delete_ray_cluster(self, ray_cluster_yaml: str, gpu_device_plugin_yaml: str) :param gpu_device_plugin_yaml: Path or URL to the GPU device plugin YAML. Defaults to NVIDIA's plugin :raises AirflowException: If there's an error deleting the Ray cluster. """ - #try: self._validate_yaml_file(ray_cluster_yaml) if gpu_device_plugin_yaml: - #Delete the NVIDIA GPU device plugin DaemonSet if it exists. + # Delete the NVIDIA GPU device plugin DaemonSet if it exists. gpu_driver = self.load_yaml_content(gpu_device_plugin_yaml) gpu_driver_name = gpu_driver["metadata"]["name"] @@ -793,7 +626,7 @@ def delete_ray_cluster(self, ray_cluster_yaml: str, gpu_device_plugin_yaml: str) self.log.info("::group:: Delete Kuberay operator") self.uninstall_kuberay_operator() self.log.info("::endgroup::") - #except Exception as e: + # except Exception as e: # self.log.error(f"Error deleting Ray cluster: {e}") # raise AirflowException(f"Failed to delete Ray cluster: {e}") @@ -886,7 +719,6 @@ def create_daemon_set(self, name: str, body: dict[str, Any]) -> client.V1DaemonS :param body: The body of the DaemonSet for the create action. :return: The created DaemonSet resource if successful, None otherwise. """ - self.log.warning("Trying to create create_daemon_set %s", name) if not body: self.log.error("Body must be provided for create action.") return None @@ -906,7 +738,6 @@ def delete_daemon_set(self, name: str) -> client.V1Status | None: :param name: The name of the DaemonSet. :return: The status of the delete operation if successful, None otherwise. """ - self.log.info("Trying to delete_daemon_set %s", name) try: delete_response = self.apps_v1_client.delete_namespaced_daemon_set(name=name, namespace=self.namespace) self.log.info(f"DaemonSet {name} deleted.") @@ -914,3 +745,125 @@ def delete_daemon_set(self, name: str) -> client.V1Status | None: except client.exceptions.ApiException as e: self.log.error(f"Exception when deleting DaemonSet: {e}") return None + + # Add this to yet another PR + def _get_node_ip(self) -> str: + """ + Retrieve the IP address of a Kubernetes node. + + :return: The IP address of a node in the Kubernetes cluster. + """ + # Example: Retrieve the first node's IP (adjust based on your cluster setup) + nodes = self.core_v1_client.list_node().items + self.log.info(f"Nodes: {nodes}") + for node in nodes: + self.log.info(f"Node address: {node.status.addresses}") + for address in node.status.addresses: + if address.type == "ExternalIP": + return address.address + + for node in nodes: + self.log.info(f"Node address: {node.status.addresses}") + for address in node.status.addresses: + if address.type == "InternalIP": + return address.address + + raise AirflowException("No valid node IP found in the cluster.") + + # Add this to yet another PR + def _setup_node_port(self, name: str, namespace: str, context: dict) -> None: + """ + Set up the NodePort service and push URLs to XCom. + + :param name: The name of the Ray cluster. + :param namespace: The namespace where the cluster is deployed. + :param context: The Airflow task context. + """ + node_port_details: dict[str, Any] = self._wait_for_node_port_service( + service_name=f"{name}-head-svc", namespace=namespace + ) + + if node_port_details: + self.log.info(node_port_details) + + node_ports = node_port_details["node_ports"] + # Example: Assuming `node_ip` is provided as an environment variable or a known cluster node. + node_ip = self._get_node_ip() # Implement this method to return a valid node IP or DNS. + + for port in node_ports: + url = f"http://{node_ip}:{port['port']}" + context["task_instance"].xcom_push(key=port["name"], value=url) + self.log.info(f"Pushed URL to XCom: {url}") + else: + self.log.info("No NodePort URLs to push to XCom.") + + # Add this to yet another PR + def _wait_for_node_port_service( + self, + service_name: str, + namespace: str = "default", + max_retries: int = 30, + retry_interval: int = 10, + ) -> dict[str, Any]: + """ + Wait for the NodePort service to be ready and return its details. + + :param service_name: The name of the NodePort service. + :param namespace: The namespace of the service. + :param max_retries: Maximum number of retries. + :param retry_interval: Interval between retries in seconds. + :return: A dictionary containing NodePort service details. + :raises AirflowException: If the service does not become ready within the specified retries. + """ + for attempt in range(1, max_retries + 1): + self.log.info(f"Attempt {attempt}: Checking NodePort service status...") + + try: + service: client.V1Service = self._get_service(service_name, namespace) + service_details: dict[str, Any] | None = self._get_node_port_details(service) + + if service_details: + self.log.info("NodePort service is ready.") + return service_details + + self.log.info("NodePort details not available yet. Retrying...") + except AirflowException: + self.log.info("Service is not available yet.") + + time.sleep(retry_interval) + + raise AirflowException(f"Service did not become ready after {max_retries} attempts") + + # Add this to yet another PR + def _get_node_port_details(self, service: client.V1Service) -> dict[str, Any] | None: + """ + Extract NodePort details from the service. + + :param service: The Kubernetes service object. + :return: A dictionary containing NodePort details if available, None otherwise. + """ + node_ports = [] + for port in service.spec.ports: + if port.node_port: + node_ports.append({"name": port.name, "port": port.node_port}) + + if node_ports: + return {"node_ports": node_ports} + + return None + + # Add this to yet another PR + def _check_node_port_connectivity(self, node_ports: list[dict[str, Any]]) -> bool: + """ + Check if the NodePort is reachable. + + :param node_ports: List of NodePort details. + :return: True if at least one NodePort is accessible, False otherwise. + """ + for port_info in node_ports: + # Replace with actual logic to test connectivity if needed. + self.log.info(f"Checking connectivity for NodePort {port_info['port']}") + # Example: Simulate readiness check. + if self._is_port_open("example-node-ip", port_info["port"]): + return True + return False From 206e57d2943f9ac181e31b6fab0562245a5e2bca Mon Sep 17 00:00:00 2001 From: Tatiana Al-Chueyr Date: Fri, 29 Nov 2024 08:26:53 +0000 Subject: [PATCH 6/6] Improve type checks --- ray_provider/hooks.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ray_provider/hooks.py b/ray_provider/hooks.py index 560a325..6f30884 100644 --- a/ray_provider/hooks.py +++ b/ray_provider/hooks.py @@ -113,14 +113,14 @@ def __init__( # Create a PR for this @cached_property - def namespace(self): + def namespace(self) -> str: if self.ray_cluster_yaml is None: return self.default_namespace cluster_spec = self.load_yaml_content(self.ray_cluster_yaml) return cluster_spec["metadata"].get("namespace") or self.default_namespace # Create another PR for this - def test_connection(self): + def test_connection(self) -> (bool, str): job_client = self.ray_client(self.address) job_id = job_client.submit_job(entrypoint="import ray; ray.init(); print(ray.cluster_resources())") @@ -420,7 +420,7 @@ def _create_or_update_cluster( name: str, namespace: str, cluster_spec: dict[str, Any], - ) -> None: + ) -> str: """ Create or update the Ray cluster based on the cluster specification.