Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Decorator configuration improvements #67

Closed
wants to merge 18 commits into from
Closed
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ Check out the Getting Started guide in our [docs](https://astronomer.github.io/a
## Sample DAGs

### Example 1: Using @ray.task for job life cycle
The below example showcases how to use the ``@ray.task`` decorator to manage the full lifecycle of a Ray cluster: setup, job execution, and teardown.
The below example showcases how to use the ``@ray.task`` decorator to manage the full lifecycle of a Ray cluster: setup, job execution, and teardown. The configuration for the decorator can be provided statically or at runtime.

This approach is ideal for jobs that require a dedicated, short-lived cluster, optimizing resource usage by cleaning up after task completion

Expand Down
3 changes: 3 additions & 0 deletions docs/getting_started/code_samples.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ The below example showcases how to use the ``@ray.task`` decorator to manage the

This approach is ideal for jobs that require a dedicated, short-lived cluster, optimizing resource usage by cleaning up after task completion.

.. note::
    Configuration can be specified either as a static dictionary or as a callable that generates the dictionary dynamically at runtime. Additional keyword arguments can be passed through to the callable when generating a dynamic configuration. See the example DAGs for reference.

.. literalinclude:: ../../example_dags/ray_taskflow_example.py
:language: python
:linenos:
Expand Down
60 changes: 60 additions & 0 deletions example_dags/ray_taskflow_example_dynamic_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from datetime import datetime
from pathlib import Path

from airflow.decorators import dag, task

from ray_provider.decorators.ray import ray


def generate_config(custom_memory: int, context=None, **kwargs):
    """Build the Ray decorator configuration dictionary at runtime.

    Invoked by the ``@ray.task`` decorator machinery, which forwards any
    decorator kwargs matching this signature (here ``custom_memory``) and
    injects the Airflow task context as the ``context`` keyword argument.

    :param custom_memory: Memory amount forwarded to the Ray job config.
    :param context: Airflow task context dict injected by the decorator
        machinery. Declared as a named parameter on purpose: with a bare
        ``**context`` catch-all the real context would arrive nested under
        the key ``"context"`` and ``context.get("execution_date")`` would
        always be ``None``.
    :param kwargs: Ignored extras, accepted for forward compatibility.
    :return: Configuration dict consumed by the Ray decorator.
    """
    conn_id = "ray_conn"
    ray_spec = Path(__file__).parent / "scripts/ray.yaml"
    folder_path = Path(__file__).parent / "ray_scripts"

    # Tolerate being called without a context (e.g. direct invocation).
    execution_date = (context or {}).get("execution_date")

    return {
        "conn_id": conn_id,
        "runtime_env": {"working_dir": str(folder_path), "pip": ["numpy"]},
        "num_cpus": 1,
        "num_gpus": 0,
        "memory": custom_memory,
        "poll_interval": 5,
        "ray_cluster_yaml": str(ray_spec),
        "xcom_task_key": "dashboard",
        "execution_date": str(execution_date),
    }


@dag(
    dag_id="Ray_Taskflow_Example_Dynamic_Config",
    start_date=datetime(2023, 1, 1),
    schedule=None,
    catchup=False,
    tags=["ray", "example"],
)
def ray_taskflow_dag():
    """Example DAG: a Ray job whose decorator config is generated at runtime."""

    @task
    def generate_data():
        # Static sample payload for the downstream Ray job.
        return [1, 2, 3]

    @ray.task(config=generate_config, custom_memory=1024)
    def process_data_with_ray(data):
        import numpy as np
        import ray

        @ray.remote
        def square(x):
            return x**2

        values = np.array(data)
        pending = [square.remote(value) for value in values]
        squared = ray.get(pending)
        mean = np.mean(squared)
        print(f"Mean of this population is {mean}")
        return mean

    process_data_with_ray(generate_data())


ray_example_dag = ray_taskflow_dag()
2 changes: 1 addition & 1 deletion ray_provider/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

__version__ = "0.2.1"
__version__ = "0.3.0a6"

from typing import Any

Expand Down
106 changes: 57 additions & 49 deletions ray_provider/decorators/ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,54 +31,23 @@ class _RayDecoratedOperator(DecoratedOperator, SubmitRayJob):

template_fields: Any = (*SubmitRayJob.template_fields, "op_args", "op_kwargs")

def __init__(self, config: dict[str, Any], **kwargs: Any) -> None:
self.conn_id: str = config.get("conn_id", "")
self.is_decorated_function = False if "entrypoint" in config else True
self.entrypoint: str = config.get("entrypoint", "python script.py")
self.runtime_env: dict[str, Any] = config.get("runtime_env", {})

self.num_cpus: int | float = config.get("num_cpus", 1)
self.num_gpus: int | float = config.get("num_gpus", 0)
self.memory: int | float = config.get("memory", None)
self.ray_resources: dict[str, Any] | None = config.get("resources", None)
self.ray_cluster_yaml: str | None = config.get("ray_cluster_yaml", None)
self.update_if_exists: bool = config.get("update_if_exists", False)
self.kuberay_version: str = config.get("kuberay_version", "1.0.0")
self.gpu_device_plugin_yaml: str = config.get(
"gpu_device_plugin_yaml",
"https://raw.githubusercontent.com/NVIDIA/k8s-device-plugin/v0.9.0/nvidia-device-plugin.yml",
)
self.fetch_logs: bool = config.get("fetch_logs", True)
self.wait_for_completion: bool = config.get("wait_for_completion", True)
def __init__(self, config: dict[str, Any] | Callable[..., dict[str, Any]], **kwargs: Any) -> None:
job_timeout_seconds: int = config.get("job_timeout_seconds", 600)
self.poll_interval: int = config.get("poll_interval", 60)
self.xcom_task_key: str | None = config.get("xcom_task_key", None)

self.config = config
self.kwargs = kwargs
super().__init__(conn_id="", entrypoint="python script.py", runtime_env={}, **kwargs)

if not isinstance(self.num_cpus, (int, float)):
raise TypeError("num_cpus should be an integer or float value")
if not isinstance(self.num_gpus, (int, float)):
raise TypeError("num_gpus should be an integer or float value")

super().__init__(
conn_id=self.conn_id,
entrypoint=self.entrypoint,
runtime_env=self.runtime_env,
num_cpus=self.num_cpus,
num_gpus=self.num_gpus,
memory=self.memory,
resources=self.ray_resources,
ray_cluster_yaml=self.ray_cluster_yaml,
update_if_exists=self.update_if_exists,
kuberay_version=self.kuberay_version,
gpu_device_plugin_yaml=self.gpu_device_plugin_yaml,
fetch_logs=self.fetch_logs,
wait_for_completion=self.wait_for_completion,
job_timeout_seconds=job_timeout_seconds,
poll_interval=self.poll_interval,
xcom_task_key=self.xcom_task_key,
**kwargs,
)
def get_config(self, context: Context, config: Callable[..., dict[str, Any]], **kwargs: Any) -> dict[str, Any]:
    """Resolve a callable config into a concrete configuration dict.

    Only keyword arguments whose names match parameters declared by
    ``config`` are forwarded; the Airflow ``context`` is injected if and
    only if the callable declares a ``context`` parameter.

    :param context: The Airflow task context.
    :param config: Callable that produces the configuration dictionary.
    :param kwargs: Candidate keyword arguments to forward to ``config``.
    :return: The dictionary returned by calling ``config``.
    """
    accepted = inspect.signature(config).parameters

    # Forward only the kwargs the callable declares; "context" is reserved
    # and handled separately below.
    call_kwargs = {name: value for name, value in kwargs.items() if name in accepted and name != "context"}

    if "context" in accepted:
        call_kwargs["context"] = context

    return config(**call_kwargs)

def execute(self, context: Context) -> Any:
"""
Expand All @@ -88,8 +57,42 @@ def execute(self, context: Context) -> Any:
:return: The result of the Ray job execution.
:raises AirflowException: If job submission fails.
"""
tmp_dir = None
temp_dir = None
try:
# Generate the configuration
if callable(self.config):
config = self.get_config(context=context, config=self.config, **self.kwargs)
else:
config = self.config

# Prepare Ray job parameters
self.conn_id: str = config.get("conn_id", "")
self.is_decorated_function = False if "entrypoint" in config else True
self.entrypoint: str = config.get("entrypoint", "python script.py")
self.runtime_env: dict[str, Any] = config.get("runtime_env", {})

self.num_cpus: int | float = config.get("num_cpus", 1)
self.num_gpus: int | float = config.get("num_gpus", 0)
self.memory: int | float = config.get("memory", None)
self.ray_resources: dict[str, Any] | None = config.get("resources", None)
self.ray_cluster_yaml: str | None = config.get("ray_cluster_yaml", None)
self.update_if_exists: bool = config.get("update_if_exists", False)
self.kuberay_version: str = config.get("kuberay_version", "1.0.0")
self.gpu_device_plugin_yaml: str = config.get(
"gpu_device_plugin_yaml",
"https://raw.githubusercontent.com/NVIDIA/k8s-device-plugin/v0.9.0/nvidia-device-plugin.yml",
)
self.fetch_logs: bool = config.get("fetch_logs", True)
self.wait_for_completion: bool = config.get("wait_for_completion", True)
self.job_timeout_seconds: int = config.get("job_timeout_seconds", 600)
self.poll_interval: int = config.get("poll_interval", 60)
self.xcom_task_key: str | None = config.get("xcom_task_key", None)

if not isinstance(self.num_cpus, (int, float)):
raise TypeError("num_cpus should be an integer or float value")
if not isinstance(self.num_gpus, (int, float)):
raise TypeError("num_gpus should be an integer or float value")

if self.is_decorated_function:
self.log.info(
f"Entrypoint is not provided, is_decorated_function is set to {self.is_decorated_function}"
Expand Down Expand Up @@ -126,8 +129,8 @@ def execute(self, context: Context) -> Any:
self.log.error(f"Failed during execution with error: {e}")
raise AirflowException("Job submission failed") from e
finally:
if tmp_dir and os.path.exists(tmp_dir):
shutil.rmtree(tmp_dir)
if temp_dir and os.path.exists(temp_dir):
shutil.rmtree(temp_dir)

def _extract_function_body(self, source: str) -> str:
"""Extract the function, excluding only the ray.task decorator."""
Expand All @@ -146,19 +149,24 @@ class ray:
def task(
    python_callable: Callable[..., Any] | None = None,
    multiple_outputs: bool | None = None,
    config: Callable[[], dict[str, Any]] | dict[str, Any] | None = None,
    **kwargs: Any,
) -> TaskDecorator:
    """
    Decorator to define a task that submits a Ray job.

    :param python_callable: The callable function to decorate.
    :param multiple_outputs: If True, will return multiple outputs.
    :param config: A dictionary of configuration or a callable that returns a dictionary.
        A callable is resolved lazily by the operator at execution time, allowing the
        configuration to be generated dynamically (e.g. from the Airflow context).
    :param kwargs: Additional keyword arguments. Extra kwargs matching the signature of
        a callable ``config`` are forwarded to it when it is invoked.
    :return: The decorated task.
    """
    # Normalize "no config" to an empty static dict so downstream code can
    # treat config uniformly as either a dict or a callable, never None.
    if config is None:
        config = {}

    return task_decorator_factory(
        python_callable=python_callable,
        multiple_outputs=multiple_outputs,
        decorated_operator_class=_RayDecoratedOperator,
        config=config,
        **kwargs,
    )
11 changes: 9 additions & 2 deletions ray_provider/operators/ray.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import traceback
from datetime import timedelta
from functools import cached_property
from typing import Any
Expand Down Expand Up @@ -281,6 +282,10 @@ def execute(self, context: Context) -> str:
current_status = self.hook.get_ray_job_status(self.dashboard_url, self.job_id)
self.log.info(f"Current job status for {self.job_id} is: {current_status}")

job_timeout_seconds = self.job_timeout_seconds
if isinstance(self.job_timeout_seconds, int):
job_timeout_seconds = timedelta(seconds=self.job_timeout_seconds) if self.job_timeout_seconds > 0 else None

if current_status not in self.terminal_states:
self.log.info("Deferring the polling to RayJobTrigger...")
self.defer(
Expand All @@ -294,7 +299,7 @@ def execute(self, context: Context) -> str:
fetch_logs=self.fetch_logs,
),
method_name="execute_complete",
timeout=self.job_timeout_seconds,
timeout=job_timeout_seconds,
)
elif current_status == JobStatus.SUCCEEDED:
self.log.info("Job %s completed successfully", self.job_id)
Expand All @@ -308,8 +313,10 @@ def execute(self, context: Context) -> str:
)
return self.job_id
except Exception as e:
self._delete_cluster()
error_details = traceback.format_exc()
self.log.info(error_details)
raise AirflowException(f"SubmitRayJob operator failed due to {e}. Cleaning up resources...")
self._delete_cluster()

def execute_complete(self, context: Context, event: dict[str, Any]) -> None:
"""
Expand Down
Loading