Support using callable config in @ray.task (#103)
Version 0.2.1 of the Ray provider only allowed users to define a hard-coded configuration to materialize the Kubernetes cluster. This PR enables users to define a function that receives the Airflow context and generates the configuration dynamically from context properties. This request came from an Astronomer customer.
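
As a minimal sketch of the new interface (hypothetical task and illustrative values; the actual example DAG is included in this PR's diff below), `config` may now be a callable instead of a dict. When its signature declares a `context` parameter, the provider injects the Airflow context while building the job configuration at execution time:

```python
from ray_provider.decorators import ray


def dynamic_config(context):
    # Derive a per-run cluster spec from the triggering DAG run.
    # The YAML path pattern below is illustrative, not prescribed by the provider.
    return {
        "conn_id": "ray_conn",
        "ray_cluster_yaml": f"/usr/local/airflow/dags/scripts/{context['dag_run'].id}.yaml",
        "poll_interval": 5,
    }


@ray.task(config=dynamic_config)
def process_data():
    ...
```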

There is an example DAG file illustrating how to use this feature. It contains a parent DAG that triggers two child DAGs, which leverage the newly introduced `@ray.task` callable configuration.

The screenshots below show the DAGs succeeding when run with the [local development instructions](https://github.com/astronomer/astro-provider-ray/blob/main/docs/getting_started/local_development_setup.rst) and the Astro CLI.

Parent DAG:
<img width="1624" alt="Screenshot 2024-11-29 at 12 15 13" src="https://github.com/user-attachments/assets/586b4575-ee62-4344-bbd7-a1a6423360ce">

Child 1 DAG:
<img width="1624" alt="Screenshot 2024-11-29 at 12 15 56" src="https://github.com/user-attachments/assets/23d89288-c68a-498e-848a-743fb2684c4f">

Example logs illustrating that the RayCluster with the dynamic configuration was created and used in Kubernetes, with its own IP address:
```
(...)
[2024-11-29T12:14:52.276+0000] {standard_task_runner.py:104} INFO - Running: ['airflow', 'tasks', 'run', 'ray_dynamic_config_child_1', 'process_data_with_ray', 'manual__2024-11-29T12:14:50.273712+00:00', '--job-id', '773', '--raw', '--subdir', 'DAGS_FOLDER/ray_dynamic_config.py', '--cfg-path', '/tmp/tmpkggwlv23']
[2024-11-29T12:14:52.278+0000] {logging_mixin.py:190} WARNING - /usr/local/lib/python3.12/site-packages/airflow/task/task_runner/standard_task_runner.py:70 DeprecationWarning: This process (pid=238) is multi-threaded, use of fork() may lead to deadlocks in the child.
(...)
[2024-11-29T12:14:52.745+0000] {decorators.py:94} INFO - Using the following config {'conn_id': 'ray_conn', 'runtime_env': {'working_dir': '/usr/local/airflow/dags/ray_scripts', 'pip': ['numpy']}, 'num_cpus': 1, 'num_gpus': 0, 'memory': 0, 'poll_interval': 5, 'ray_cluster_yaml': '/usr/local/airflow/dags/scripts/first-254.yaml', 'xcom_task_key': 'dashboard'}
(...)
[2024-11-29T12:14:55.430+0000] {hooks.py:474} INFO - ::group::Create Ray Cluster
[2024-11-29T12:14:55.430+0000] {hooks.py:475} INFO - Loading yaml content for Ray cluster CRD...
[2024-11-29T12:14:55.451+0000] {hooks.py:410} INFO - Creating new Ray cluster: first-254
[2024-11-29T12:14:55.456+0000] {hooks.py:494} INFO - ::endgroup::
(...)
[2024-11-29T12:14:55.663+0000] {hooks.py:498} INFO - ::group::Setup Load Balancer service
[2024-11-29T12:14:55.663+0000] {hooks.py:334} INFO - Attempt 1: Checking LoadBalancer status...
[2024-11-29T12:14:55.669+0000] {hooks.py:278} ERROR - Error getting service first-254-head-svc: (404)
Reason: Not Found
HTTP response headers: HTTPHeaderDict({'Audit-Id': '81b07ac4-db3b-48a6-b336-f52ae93bee55', 'Cache-Control': 'no-cache, private', 'Content-Type': 'application/json', 'X-Kubernetes-Pf-Flowschema-Uid': '955e8bb0-08b1-4d45-a768-e49387a9767c', 'X-Kubernetes-Pf-Prioritylevel-Uid': 'd5240328-288d-4366-b094-d8fd793c7431', 'Date': 'Fri, 29 Nov 2024 12:14:55 GMT', 'Content-Length': '212'})
HTTP response body: {"kind":"Status","apiVersion":"v1","metadata":{},"status":"Failure","message":"services \"first-254-head-svc\" not found","reason":"NotFound","details":{"name":"first-254-head-svc","kind":"services"},"code":404}
[2024-11-29T12:14:55.669+0000] {hooks.py:355} INFO - LoadBalancer service is not available yet...
[2024-11-29T12:15:35.670+0000] {hooks.py:334} INFO - Attempt 2: Checking LoadBalancer status...
[2024-11-29T12:15:35.688+0000] {hooks.py:348} INFO - LoadBalancer is ready.
[2024-11-29T12:15:35.688+0000] {hooks.py:441} INFO - {'ip': '172.18.255.1', 'hostname': None, 'ports': [{'name': 'client', 'port': 10001}, {'name': 'dashboard', 'port': 8265}, {'name': 'gcs', 'port': 6379}, {'name': 'metrics', 'port': 8080}, {'name': 'serve', 'port': 8000}], 'working_address': '172.18.255.1'}

(...)

[2024-11-29T12:15:38.345+0000] {triggers.py:124} INFO - ::group:: Trigger 1/2: Checking the job status
[2024-11-29T12:15:38.345+0000] {triggers.py:125} INFO - Polling for job raysubmit_paxAkyLiKxEHPmwG every 5 seconds...
(...)
[2024-11-29T12:15:38.354+0000] {hooks.py:156} INFO - Dashboard URL is: http://172.18.255.1:8265
[2024-11-29T12:15:38.361+0000] {hooks.py:208} INFO - Job raysubmit_paxAkyLiKxEHPmwG status: PENDING
[2024-11-29T12:15:38.361+0000] {triggers.py:100} INFO - Status of job raysubmit_paxAkyLiKxEHPmwG is: PENDING
[2024-11-29T12:15:38.361+0000] {triggers.py:108} INFO - ::group::raysubmit_paxAkyLiKxEHPmwG logs
[2024-11-29T12:15:43.416+0000] {hooks.py:208} INFO - Job raysubmit_paxAkyLiKxEHPmwG status: RUNNING
[2024-11-29T12:15:43.416+0000] {triggers.py:100} INFO - Status of job raysubmit_paxAkyLiKxEHPmwG is: RUNNING
[2024-11-29T12:15:43.417+0000] {triggers.py:112} INFO - 2024-11-29 04:15:40,813	INFO worker.py:1429 -- Using address 10.244.0.140:6379 set in the environment variable RAY_ADDRESS
[2024-11-29T12:15:43.417+0000] {triggers.py:112} INFO - 2024-11-29 04:15:40,814	INFO worker.py:1564 -- Connecting to existing Ray cluster at address: 10.244.0.140:6379...
[2024-11-29T12:15:43.417+0000] {triggers.py:112} INFO - 2024-11-29 04:15:40,820 INFO worker.py:1740 -- Connected to Ray cluster. View the dashboard at 10.244.0.140:8265
[2024-11-29T12:15:48.430+0000] {hooks.py:208} INFO - Job raysubmit_paxAkyLiKxEHPmwG status: SUCCEEDED
[2024-11-29T12:15:48.430+0000] {triggers.py:112} INFO - Mean of this population is 12.0
[2024-11-29T12:15:48.430+0000] {triggers.py:112} INFO - (autoscaler +5s) Tip: use `ray status` to view detailed cluster status. To disable these messages, set RAY_SCHEDULER_EVENTS=0.
[2024-11-29T12:15:48.430+0000] {triggers.py:112} INFO - (autoscaler +5s) Adding 1 node(s) of type small-group.
[2024-11-29T12:15:49.448+0000] {triggers.py:113} INFO - ::endgroup::
[2024-11-29T12:15:49.448+0000] {triggers.py:144} INFO - ::endgroup::
[2024-11-29T12:15:49.448+0000] {triggers.py:145} INFO - ::group:: Trigger 2/2: Job reached a terminal state
[2024-11-29T12:15:49.448+0000] {triggers.py:146} INFO - Status of completed job raysubmit_paxAkyLiKxEHPmwG is: SUCCEEDED
(...)
```

Child 2 DAG:
<img width="1624" alt="Screenshot 2024-11-29 at 12 17 20" src="https://github.com/user-attachments/assets/5f0320a2-3bce-49a9-8580-a584d1f894dc">

Kubernetes RayClusters spun up:
<img width="758" alt="Screenshot 2024-11-29 at 12 15 37" src="https://github.com/user-attachments/assets/aabcce4a-ce4a-47db-b9cf-e87cf68f2316">


**Limitations**

The example DAGs are not currently executed in the CI, but there is a dedicated ticket for this work:
#95

**References**

This PR was inspired by:
#67
tatiana authored Nov 29, 2024
1 parent 5b15f3a commit 56387cc
Showing 3 changed files with 309 additions and 51 deletions.
197 changes: 197 additions & 0 deletions dev/dags/ray_dynamic_config.py
```python
"""
This example illustrates three DAGs. One
The parent DAG (ray_dynamic_config_upstream_dag) uses TriggerDagRunOperator to trigger the other two:
* ray_dynamic_config_downstream_dag_1
* ray_dynamic_config_downstream_dag_2
Each downstream DAG retrieves the context data (run_context) from dag_run.conf, which is passed by the parent DAG.
The print_context tasks in the downstream DAGs output the received context to the logs.
"""

import re
from pathlib import Path

import yaml
from airflow import DAG
from airflow.decorators import task
from airflow.operators.empty import EmptyOperator
from airflow.operators.python import PythonOperator
from airflow.operators.trigger_dagrun import TriggerDagRunOperator
from airflow.utils.dates import days_ago
from jinja2 import Template

from ray_provider.decorators import ray

CONN_ID = "ray_conn"
RAY_SPEC = Path(__file__).parent / "scripts/ray.yaml"
FOLDER_PATH = Path(__file__).parent / "ray_scripts"
RAY_TASK_CONFIG = {
    "conn_id": CONN_ID,
    "runtime_env": {"working_dir": str(FOLDER_PATH), "pip": ["numpy"]},
    "num_cpus": 1,
    "num_gpus": 0,
    "memory": 0,
    "poll_interval": 5,
    "ray_cluster_yaml": str(RAY_SPEC),
    "xcom_task_key": "dashboard",
}


def slugify(value):
    """
    Replace invalid characters with hyphens and make lowercase.
    """
    return re.sub(r"[^\w\-\.]", "-", value).lower()


def create_config_from_context(context, **kwargs):
    default_name = "{{ dag.dag_id }}-{{ dag_run.id }}"

    raycluster_name_template = context.get("dag_run").conf.get("raycluster_name", default_name)
    raycluster_name = Template(raycluster_name_template).render(context).replace("_", "-")
    raycluster_name = slugify(raycluster_name)

    raycluster_k8s_yml_filename_template = context.get("dag_run").conf.get(
        "raycluster_k8s_yml_filename", default_name + ".yml"
    )
    raycluster_k8s_yml_filename = Template(raycluster_k8s_yml_filename_template).render(context).replace("_", "-")
    raycluster_k8s_yml_filename = slugify(raycluster_k8s_yml_filename)

    with open(RAY_SPEC) as file:
        data = yaml.safe_load(file)
        data["metadata"]["name"] = raycluster_name

    NEW_RAY_K8S_SPEC = Path(__file__).parent / "scripts" / raycluster_k8s_yml_filename
    with open(NEW_RAY_K8S_SPEC, "w") as file:
        yaml.safe_dump(data, file, default_flow_style=False)

    config = dict(RAY_TASK_CONFIG)
    config["ray_cluster_yaml"] = str(NEW_RAY_K8S_SPEC)
    return config


def print_context(**kwargs):
    # Retrieve `conf` passed from the parent DAG
    print(kwargs)
    cluster_name = kwargs.get("dag_run").conf.get("raycluster_name", "No ray cluster name provided")
    raycluster_k8s_yml_filename = kwargs.get("dag_run").conf.get(
        "raycluster_k8s_yml_filename", "No ray cluster YML filename provided"
    )
    print(f"Received cluster name: {cluster_name}")
    print(f"Received cluster K8s YML filename: {raycluster_k8s_yml_filename}")


# Downstream 1
with DAG(
    dag_id="ray_dynamic_config_child_1",
    start_date=days_ago(1),
    schedule_interval=None,
    catchup=False,
) as dag:

    print_context_task = PythonOperator(
        task_id="print_context",
        python_callable=print_context,
    )
    print_context_task

    @task
    def generate_data():
        return [1, 2, 3]

    @ray.task(config=create_config_from_context)
    def process_data_with_ray(data):
        import numpy as np
        import ray

        @ray.remote
        def cubic(x):
            return x**3

        ray.init()
        data = np.array(data)
        futures = [cubic.remote(x) for x in data]
        results = ray.get(futures)
        mean = np.mean(results)
        print(f"Mean of this population is {mean}")
        return mean

    data = generate_data()
    process_data_with_ray(data)


# Downstream 2
with DAG(
    dag_id="ray_dynamic_config_child_2",
    start_date=days_ago(1),
    schedule_interval=None,
    catchup=False,
) as dag:

    print_context_task = PythonOperator(
        task_id="print_context",
        python_callable=print_context,
    )

    @task
    def generate_data():
        return [1, 2, 3]

    @ray.task(config=create_config_from_context)
    def process_data_with_ray(data):
        import numpy as np
        import ray

        @ray.remote
        def square(x):
            return x**2

        ray.init()
        data = np.array(data)
        futures = [square.remote(x) for x in data]
        results = ray.get(futures)
        mean = np.mean(results)
        print(f"Mean of this population is {mean}")
        return mean

    data = generate_data()
    process_data_with_ray(data)


# Upstream
with DAG(
    dag_id="ray_dynamic_config_parent",
    start_date=days_ago(1),
    schedule_interval=None,
    catchup=False,
) as dag:
    empty_task = EmptyOperator(task_id="empty_task")

    trigger_dag_1 = TriggerDagRunOperator(
        task_id="trigger_downstream_dag_1",
        trigger_dag_id="ray_dynamic_config_child_1",
        conf={
            "raycluster_name": "first-{{ dag_run.id }}",
            "raycluster_k8s_yml_filename": "first-{{ dag_run.id }}.yaml",
        },
    )

    trigger_dag_2 = TriggerDagRunOperator(
        task_id="trigger_downstream_dag_2",
        trigger_dag_id="ray_dynamic_config_child_2",
        conf={},
    )

    # Illustrates that by default two DAG runs of the same DAG will be using different Ray clusters
    # Disabled because in the local dev MacOS we're only managing to spin up two Ray Cluster services concurrently
    # trigger_dag_3 = TriggerDagRunOperator(
    #     task_id="trigger_downstream_dag_3",
    #     trigger_dag_id="ray_dynamic_config_child_2",
    #     conf={},
    # )

    empty_task >> trigger_dag_1
    trigger_dag_1 >> trigger_dag_2
    # trigger_dag_1 >> trigger_dag_3
```
101 changes: 57 additions & 44 deletions ray_provider/decorators.py
```diff
@@ -3,15 +3,16 @@
 import inspect
 import os
 import re
-import shutil
+import tempfile
 import textwrap
-from tempfile import mkdtemp
+from datetime import timedelta
 from pathlib import Path
 from typing import Any, Callable
 
 from airflow.decorators.base import DecoratedOperator, TaskDecorator, task_decorator_factory
-from airflow.exceptions import AirflowException
 from airflow.utils.context import Context
 
+from ray_provider.exceptions import RayAirflowException
 from ray_provider.operators import SubmitRayJob
@@ -28,20 +29,37 @@ class _RayDecoratedOperator(DecoratedOperator, SubmitRayJob):
     """
 
     custom_operator_name = "@task.ray"
+    _config: dict[str, Any] | Callable[..., dict[str, Any]] = dict()
 
     template_fields: Any = (*SubmitRayJob.template_fields, "op_args", "op_kwargs")
 
-    def __init__(self, config: dict[str, Any], **kwargs: Any) -> None:
+    def __init__(self, config: dict[str, Any] | Callable[..., dict[str, Any]], **kwargs: Any) -> None:
+        self._config = config
+        self.kwargs = kwargs
+        super().__init__(conn_id="", entrypoint="python script.py", runtime_env={}, **kwargs)
+
+    def _build_config(self, context: Context) -> dict[str, Any]:
+        if callable(self._config):
+            config_params = inspect.signature(self._config).parameters
+            config_kwargs = {k: v for k, v in self.kwargs.items() if k in config_params and k != "context"}
+            if "context" in config_params:
+                config_kwargs["context"] = context
+            config = self._config(**config_kwargs)
+            assert isinstance(config, dict)
+            return config
+        return self._config
+
+    def _load_config(self, config: dict[str, Any]) -> None:
         self.conn_id: str = config.get("conn_id", "")
         self.is_decorated_function = False if "entrypoint" in config else True
         self.entrypoint: str = config.get("entrypoint", "python script.py")
         self.runtime_env: dict[str, Any] = config.get("runtime_env", {})
 
         self.num_cpus: int | float = config.get("num_cpus", 1)
         self.num_gpus: int | float = config.get("num_gpus", 0)
-        self.memory: int | float = config.get("memory", None)
-        self.ray_resources: dict[str, Any] | None = config.get("resources", None)
-        self.ray_cluster_yaml: str | None = config.get("ray_cluster_yaml", None)
+        self.memory: int | float = config.get("memory", 1)
+        self.ray_resources: dict[str, Any] | None = config.get("resources")
+        self.ray_cluster_yaml: str | None = config.get("ray_cluster_yaml")
         self.update_if_exists: bool = config.get("update_if_exists", False)
         self.kuberay_version: str = config.get("kuberay_version", "1.0.0")
         self.gpu_device_plugin_yaml: str = config.get(
@@ -50,35 +68,19 @@ def __init__(self, config: dict[str, Any], **kwargs: Any) -> None:
         )
         self.fetch_logs: bool = config.get("fetch_logs", True)
         self.wait_for_completion: bool = config.get("wait_for_completion", True)
-        job_timeout_seconds: int = config.get("job_timeout_seconds", 600)
+        job_timeout_seconds = config.get("job_timeout_seconds", 600)
+        self.job_timeout_seconds: timedelta | None = (
+            timedelta(seconds=job_timeout_seconds) if job_timeout_seconds > 0 else None
+        )
         self.poll_interval: int = config.get("poll_interval", 60)
-        self.xcom_task_key: str | None = config.get("xcom_task_key", None)
+        self.xcom_task_key: str | None = config.get("xcom_task_key")
 
         self.config = config
 
         if not isinstance(self.num_cpus, (int, float)):
-            raise TypeError("num_cpus should be an integer or float value")
+            raise RayAirflowException("num_cpus should be an integer or float value")
         if not isinstance(self.num_gpus, (int, float)):
-            raise TypeError("num_gpus should be an integer or float value")
-
-        super().__init__(
-            conn_id=self.conn_id,
-            entrypoint=self.entrypoint,
-            runtime_env=self.runtime_env,
-            num_cpus=self.num_cpus,
-            num_gpus=self.num_gpus,
-            memory=self.memory,
-            resources=self.ray_resources,
-            ray_cluster_yaml=self.ray_cluster_yaml,
-            update_if_exists=self.update_if_exists,
-            kuberay_version=self.kuberay_version,
-            gpu_device_plugin_yaml=self.gpu_device_plugin_yaml,
-            fetch_logs=self.fetch_logs,
-            wait_for_completion=self.wait_for_completion,
-            job_timeout_seconds=job_timeout_seconds,
-            poll_interval=self.poll_interval,
-            xcom_task_key=self.xcom_task_key,
-            **kwargs,
-        )
+            raise RayAirflowException("num_gpus should be an integer or float value")
 
     def execute(self, context: Context) -> Any:
         """
@@ -88,28 +90,29 @@ def execute(self, context: Context) -> Any:
         :return: The result of the Ray job execution.
         :raises AirflowException: If job submission fails.
         """
-        tmp_dir = None
-        try:
+        config = self._build_config(context)
+        self.log.info(f"Using the following config {config}")
+        self._load_config(config)
+
+        with tempfile.TemporaryDirectory(prefix="ray_") as tmpdirname:
+            temp_dir = Path(tmpdirname)
+
             if self.is_decorated_function:
                 self.log.info(
                     f"Entrypoint is not provided, is_decorated_function is set to {self.is_decorated_function}"
                 )
-                # Create a temporary directory that won't be immediately deleted
-                temp_dir = mkdtemp(prefix="ray_")
-                script_filename = os.path.join(temp_dir, "script.py")
 
                 # Get the Python source code and extract just the function body
                 full_source = inspect.getsource(self.python_callable)
                 function_body = self._extract_function_body(full_source)
-                if not function_body:
-                    raise ValueError("Failed to retrieve Python source code")
 
                 # Prepare the function call
                 args_str = ", ".join(repr(arg) for arg in self.op_args)
                 kwargs_str = ", ".join(f"{k}={repr(v)}" for k, v in self.op_kwargs.items())
                 call_str = f"{self.python_callable.__name__}({args_str}, {kwargs_str})"
 
                 # Write the script with function definition and call
+                script_filename = os.path.join(temp_dir, "script.py")
                 with open(script_filename, "w") as file:
                     file.write(function_body)
                     file.write(f"\n\n# Execute the function\n{call_str}\n")
@@ -122,21 +125,27 @@
             result = super().execute(context)
 
             return result
-        except Exception as e:
-            self.log.error(f"Failed during execution with error: {e}")
-            raise AirflowException("Job submission failed") from e
-        finally:
-            if tmp_dir and os.path.exists(tmp_dir):
-                shutil.rmtree(tmp_dir)
 
     def _extract_function_body(self, source: str) -> str:
         """Extract the function, excluding only the ray.task decorator."""
-        self.log.info(r"Ray pipeline intended to be executed: \n %s", source)
+        if "@ray.task" not in source:
+            raise RayAirflowException("Unable to parse this body. Expects the `@ray.task` decorator.")
         lines = source.split("\n")
+        # TODO: Review the current approach, that is quite hacky.
+        # It feels a mistake to have a user-facing module named the same as the official ray SDK.
+        # In particular, the decorator is working in a very artificial way, where ray means two different things
+        # at the scope of the task definition (Astro Ray Provider decorator) and inside the decorated method (Ray SDK)
-        # Find the line where the ray.task decorator is
+        # Additionally, if users imported the ray decorator as "from ray_provider.decorators import ray as ray_decorator"
+        # the following will stop working.
         ray_task_line = next((i for i, line in enumerate(lines) if re.match(r"^\s*@ray\.task", line.strip())), -1)
 
         # Include everything except the ray.task decorator line
         body = "\n".join(lines[:ray_task_line] + lines[ray_task_line + 1 :])
 
+        if not body:
+            raise RayAirflowException("Failed to extract Ray pipeline code decorated with @ray.task")
         # Dedent the body
         return textwrap.dedent(body)
 
@@ -146,19 +155,23 @@ class ray:
     def task(
         python_callable: Callable[..., Any] | None = None,
         multiple_outputs: bool | None = None,
+        config: dict[str, Any] | Callable[[], dict[str, Any]] | None = None,
         **kwargs: Any,
     ) -> TaskDecorator:
         """
         Decorator to define a task that submits a Ray job.
         :param python_callable: The callable function to decorate.
         :param multiple_outputs: If True, will return multiple outputs.
+        :param config: A dictionary of configuration or a callable that returns a dictionary.
        :param kwargs: Additional keyword arguments.
        :return: The decorated task.
        """
+        config = config or {}
        return task_decorator_factory(
            python_callable=python_callable,
            multiple_outputs=multiple_outputs,
            decorated_operator_class=_RayDecoratedOperator,
+            config=config,
            **kwargs,
        )
```
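
To make the dispatch above easier to follow, here is a standalone sketch of the logic `_build_config` implements, with simplified names outside the operator class: call the user's config callable with only the keyword arguments its signature declares, inject the Airflow context if requested, and fall back to returning the plain dictionary otherwise:

```python
import inspect
from typing import Any, Callable


def build_config(
    config: dict[str, Any] | Callable[..., dict[str, Any]],
    context: dict[str, Any],
    available_kwargs: dict[str, Any],
) -> dict[str, Any]:
    if callable(config):
        params = inspect.signature(config).parameters
        # Forward only the kwargs the callable declares, reserving "context".
        kwargs = {k: v for k, v in available_kwargs.items() if k in params and k != "context"}
        if "context" in params:
            kwargs["context"] = context
        result = config(**kwargs)
        assert isinstance(result, dict)
        return result
    return config


# A config callable that only declares `context` receives exactly that:
print(build_config(lambda context: {"conn_id": context["conn_id"]}, {"conn_id": "ray_conn"}, {}))
```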