Add ExternalDeploymentSensor (#1472)
* Introduce the Astro Cloud connection type and add ExternalDeploymentSensor, which uses the Astro API to monitor the status of a DAG or a task in an Astro deployment. The sensor takes external_dag_id as a required param and external_task_id as an optional param: if external_task_id is provided, it monitors that task instance; otherwise it monitors the overall status of external_dag_id. When execution starts, the sensor fetches the DAG run for external_dag_id with status running or queued. If one is found, it waits in the trigger component for the run to succeed or fail; if not found, it returns immediately and is marked a success. Assumption: when this sensor starts running, it assumes the DAG it is monitoring is already in a running or queued state.
1 parent 5961420 · commit 3ca9989
Showing 10 changed files with 855 additions and 1 deletion.
@@ -0,0 +1,63 @@
import time
from datetime import datetime

from airflow import DAG
from airflow.decorators import task
from airflow.operators.trigger_dagrun import TriggerDagRunOperator

from astronomer.providers.core.sensors.astro import ExternalDeploymentSensor

with DAG(
    dag_id="example_astro_task",
    start_date=datetime(2022, 1, 1),
    schedule=None,
    catchup=False,
    tags=["example", "async", "core"],
):
    ExternalDeploymentSensor(
        task_id="test1",
        external_dag_id="example_wait_to_test_example_astro_task",
    )

    ExternalDeploymentSensor(
        task_id="test2",
        external_dag_id="example_wait_to_test_example_astro_task",
        external_task_id="wait_for_2_min",
    )


with DAG(
    dag_id="example_wait_to_test_example_astro_task",
    start_date=datetime(2022, 1, 1),
    schedule=None,
    catchup=False,
    tags=["example", "async", "core"],
):

    @task
    def wait_for_2_min() -> None:
        """Wait for 2 min."""
        time.sleep(120)

    wait_for_2_min()


with DAG(
    dag_id="trigger_astro_test_and_example",
    start_date=datetime(2022, 1, 1),
    schedule=None,
    catchup=False,
    tags=["example", "async", "core"],
):
    run_wait_dag = TriggerDagRunOperator(
        task_id="run_wait_dag",
        trigger_dag_id="example_wait_to_test_example_astro_task",
        wait_for_completion=False,
    )

    run_astro_dag = TriggerDagRunOperator(
        task_id="run_astro_dag",
        trigger_dag_id="example_astro_task",
        wait_for_completion=False,
    )

    run_wait_dag >> run_astro_dag
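
The two sensors in example_astro_task are deliberately left unchained, so each pokes independently; triggering trigger_astro_test_and_example starts the waiting DAG first and then the sensor DAG, giving the sensors a running DAG run to observe. Because ExternalDeploymentSensor defers via BaseSensorOperator's standard settings, the usual sensor knobs carry over to the deferred wait. A minimal sketch with illustrative values (this task is not part of the committed example):

# Illustrative only, inside a DAG context: poke_interval sets how often the
# trigger re-checks the external deployment, and timeout bounds the deferred
# wait; both come from BaseSensorOperator and are forwarded by the sensor's
# execute()/defer() (see the sensor source further below).
ExternalDeploymentSensor(
    task_id="wait_for_external_task",
    external_dag_id="example_wait_to_test_example_astro_task",
    external_task_id="wait_for_2_min",
    astro_cloud_conn_id="astro_cloud_default",
    poke_interval=60,
    timeout=60 * 60,
)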
@@ -0,0 +1,156 @@
from __future__ import annotations

import os
from typing import Any
from urllib.parse import quote

import requests
from aiohttp import ClientSession
from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook


class AstroHook(BaseHook):
    """
    Custom Apache Airflow Hook for interacting with Astro Cloud API.

    :param astro_cloud_conn_id: The connection ID to retrieve Astro Cloud credentials.
    """

    conn_name_attr = "astro_cloud_conn_id"
    default_conn_name = "astro_cloud_default"
    conn_type = "Astro Cloud"
    hook_name = "Astro Cloud"

    def __init__(self, astro_cloud_conn_id: str = "astro_cloud_conn_id"):
        super().__init__()
        self.astro_cloud_conn_id = astro_cloud_conn_id

    @classmethod
    def get_ui_field_behaviour(cls) -> dict[str, Any]:
        """
        Returns UI field behavior customization for the Astro Cloud connection.

        This method defines hidden fields, relabeling, and placeholders for UI display.
        """
        return {
            "hidden_fields": ["login", "port", "schema", "extra"],
            "relabeling": {
                "password": "Astro Cloud API Token",
            },
            "placeholders": {
                "host": "https://clmkpsyfc010391acjie00t1l.astronomer.run/d5lc9c9x",
                "password": "Astro API JWT Token",
            },
        }

    def get_conn(self) -> tuple[str, str]:
        """Retrieves the Astro Cloud connection details."""
        conn = BaseHook.get_connection(self.astro_cloud_conn_id)
        base_url = conn.host or os.environ.get("AIRFLOW__WEBSERVER__BASE_URL")
        if base_url is None:
            raise AirflowException(f"Airflow host is missing in connection {self.astro_cloud_conn_id}")
        token = conn.password
        if token is None:
            raise AirflowException(f"Astro API token is missing in connection {self.astro_cloud_conn_id}")
        return base_url, token

    @property
    def _headers(self) -> dict[str, str]:
        """Generates and returns headers for Astro Cloud API requests."""
        _, token = self.get_conn()
        headers = {"accept": "application/json", "Authorization": f"Bearer {token}"}
        return headers

    def get_dag_runs(self, external_dag_id: str) -> list[dict[str, str]]:
        """
        Retrieves information about running or queued DAG runs.

        :param external_dag_id: External ID of the DAG.
        """
        base_url, _ = self.get_conn()
        path = f"/api/v1/dags/{external_dag_id}/dagRuns"
        params: dict[str, int | str | list[str]] = {
            "limit": 1,
            "state": ["running", "queued"],
            "order_by": "-execution_date",
        }
        url = f"{base_url}{path}"
        response = requests.get(url, headers=self._headers, params=params)
        response.raise_for_status()
        data: dict[str, list[dict[str, str]]] = response.json()
        return data["dag_runs"]

    def get_dag_run(self, external_dag_id: str, dag_run_id: str) -> dict[str, Any] | None:
        """
        Retrieves information about a specific DAG run.

        :param external_dag_id: External ID of the DAG.
        :param dag_run_id: ID of the DAG run.
        """
        base_url, _ = self.get_conn()
        dag_run_id = quote(dag_run_id)
        path = f"/api/v1/dags/{external_dag_id}/dagRuns/{dag_run_id}"
        url = f"{base_url}{path}"
        response = requests.get(url, headers=self._headers)
        response.raise_for_status()
        dr: dict[str, Any] = response.json()
        return dr

    async def get_a_dag_run(self, external_dag_id: str, dag_run_id: str) -> dict[str, Any] | None:
        """
        Retrieves information about a specific DAG run.

        :param external_dag_id: External ID of the DAG.
        :param dag_run_id: ID of the DAG run.
        """
        base_url, _ = self.get_conn()
        dag_run_id = quote(dag_run_id)
        path = f"/api/v1/dags/{external_dag_id}/dagRuns/{dag_run_id}"
        url = f"{base_url}{path}"

        async with ClientSession(headers=self._headers) as session:
            async with session.get(url) as response:
                response.raise_for_status()
                dr: dict[str, Any] = await response.json()
                return dr

    def get_task_instance(
        self, external_dag_id: str, dag_run_id: str, external_task_id: str
    ) -> dict[str, Any] | None:
        """
        Retrieves information about a specific task instance within a DAG run.

        :param external_dag_id: External ID of the DAG.
        :param dag_run_id: ID of the DAG run.
        :param external_task_id: External ID of the task.
        """
        base_url, _ = self.get_conn()
        dag_run_id = quote(dag_run_id)
        path = f"/api/v1/dags/{external_dag_id}/dagRuns/{dag_run_id}/taskInstances/{external_task_id}"
        url = f"{base_url}{path}"
        response = requests.get(url, headers=self._headers)
        response.raise_for_status()
        ti: dict[str, Any] = response.json()
        return ti

    async def get_a_task_instance(
        self, external_dag_id: str, dag_run_id: str, external_task_id: str
    ) -> dict[str, Any] | None:
        """
        Retrieves information about a specific task instance within a DAG run.

        :param external_dag_id: External ID of the DAG.
        :param dag_run_id: ID of the DAG run.
        :param external_task_id: External ID of the task.
        """
        base_url, _ = self.get_conn()
        dag_run_id = quote(dag_run_id)
        path = f"/api/v1/dags/{external_dag_id}/dagRuns/{dag_run_id}/taskInstances/{external_task_id}"
        url = f"{base_url}{path}"

        async with ClientSession(headers=self._headers) as session:
            async with session.get(url) as response:
                response.raise_for_status()
                ti: dict[str, Any] = await response.json()
                return ti
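
Since the connection type hides login, port, schema, and extra, a working Astro Cloud connection only needs host (the deployment's Airflow URL) and password (the API token). A minimal sketch of driving the hook directly, assuming a connection named astro_cloud_default exists; all values below are illustrative:

# Illustrative: the connection can be supplied via Airflow's JSON connection
# env-var format (Airflow 2.3+), e.g. before starting Airflow:
#   AIRFLOW_CONN_ASTRO_CLOUD_DEFAULT='{"conn_type": "Astro Cloud",
#       "host": "https://<deployment>.astronomer.run/<id>",
#       "password": "<astro-api-jwt-token>"}'
from astronomer.providers.core.hooks.astro import AstroHook

hook = AstroHook(astro_cloud_conn_id="astro_cloud_default")

# Latest running/queued run of the external DAG (empty list if none).
runs = hook.get_dag_runs("example_wait_to_test_example_astro_task")
if runs:
    dag_run_id = runs[0]["dag_run_id"]
    # Poll the run and one of its task instances via the Airflow REST API.
    run = hook.get_dag_run("example_wait_to_test_example_astro_task", dag_run_id)
    ti = hook.get_task_instance(
        "example_wait_to_test_example_astro_task", dag_run_id, "wait_for_2_min"
    )
    print(run["state"] if run else None, ti["state"] if ti else None)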
@@ -0,0 +1,100 @@
from __future__ import annotations

import datetime
from typing import Any, cast

from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.sensors.base import BaseSensorOperator, PokeReturnValue

from astronomer.providers.core.hooks.astro import AstroHook
from astronomer.providers.core.triggers.astro import AstroDeploymentTrigger
from astronomer.providers.utils.typing_compat import Context


class ExternalDeploymentSensor(BaseSensorOperator):
    """
    Custom Apache Airflow sensor for monitoring external deployments using Astro Cloud.

    :param external_dag_id: External ID of the DAG being monitored.
    :param astro_cloud_conn_id: The connection ID to retrieve Astro Cloud credentials.
        Defaults to "astro_cloud_default".
    :param external_task_id: External ID of the task being monitored. If None, monitors the entire DAG.
    :param kwargs: Additional keyword arguments passed to the BaseSensorOperator constructor.
    """

    def __init__(
        self,
        external_dag_id: str,
        astro_cloud_conn_id: str = "astro_cloud_default",
        external_task_id: str | None = None,
        **kwargs: Any,
    ):
        super().__init__(**kwargs)
        self.astro_cloud_conn_id = astro_cloud_conn_id
        self.external_task_id = external_task_id
        self.external_dag_id = external_dag_id
        self._dag_run_id: str = ""

    def poke(self, context: Context) -> bool | PokeReturnValue:
        """
        Check the status of a DAG/task in another deployment.

        Queries Airflow's REST API for the status of the specified DAG or task instance.
        Returns True if successful, False otherwise.

        :param context: The task execution context.
        """
        hook = AstroHook(self.astro_cloud_conn_id)
        dag_runs: list[dict[str, Any]] = hook.get_dag_runs(self.external_dag_id)
        if not dag_runs:
            self.log.info("No DAG runs found for DAG %s", self.external_dag_id)
            return True
        self._dag_run_id = cast(str, dag_runs[0]["dag_run_id"])
        if self.external_task_id is not None:
            task_instance = hook.get_task_instance(
                self.external_dag_id, self._dag_run_id, self.external_task_id
            )
            task_state = task_instance.get("state") if task_instance else None
            if task_state == "success":
                return True
        else:
            state = dag_runs[0].get("state")
            if state == "success":
                return True
        return False

    def execute(self, context: Context) -> Any:
        """
        Executes the sensor.

        If the external deployment is not successful, it defers the execution using an AstroDeploymentTrigger.

        :param context: The task execution context.
        """
        if not self.poke(context):
            self.defer(
                timeout=datetime.timedelta(seconds=self.timeout),
                trigger=AstroDeploymentTrigger(
                    astro_cloud_conn_id=self.astro_cloud_conn_id,
                    external_task_id=self.external_task_id,
                    external_dag_id=self.external_dag_id,
                    poke_interval=self.poke_interval,
                    dag_run_id=self._dag_run_id,
                ),
                method_name="execute_complete",
            )

    def execute_complete(self, context: Context, event: dict[str, str]) -> None:
        """
        Handles the completion event from the deferred execution.

        Raises AirflowSkipException if the upstream job failed and `soft_fail` is True.
        Otherwise, raises AirflowException.

        :param context: The task execution context.
        :param event: The event dictionary received from the deferred execution.
        """
        if event.get("status") == "failed":
            if self.soft_fail:
                raise AirflowSkipException("Upstream job failed. Skipping the task.")
            raise AirflowException("Upstream job failed.")
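
The AstroDeploymentTrigger imported above is among this commit's ten changed files but is not rendered in this view. What follows is a hypothetical sketch, not the committed implementation: a trigger of roughly this shape would satisfy the contract that execute() and execute_complete() assume, with constructor arguments mirroring the defer() call above and an event payload inferred from execute_complete(), which treats any status other than "failed" as success.

# Hypothetical sketch (not the committed code): a trigger with this shape
# would satisfy the sensor's defer()/execute_complete() contract.
from __future__ import annotations

import asyncio
from typing import Any, AsyncIterator

from airflow.triggers.base import BaseTrigger, TriggerEvent

from astronomer.providers.core.hooks.astro import AstroHook


class AstroDeploymentTrigger(BaseTrigger):
    """Waits for an external DAG run (or one task instance in it) to finish."""

    def __init__(
        self,
        external_dag_id: str,
        dag_run_id: str,
        external_task_id: str | None = None,
        astro_cloud_conn_id: str = "astro_cloud_default",
        poke_interval: float = 5.0,
    ):
        super().__init__()
        self.external_dag_id = external_dag_id
        self.dag_run_id = dag_run_id
        self.external_task_id = external_task_id
        self.astro_cloud_conn_id = astro_cloud_conn_id
        self.poke_interval = poke_interval

    def serialize(self) -> tuple[str, dict[str, Any]]:
        # Triggers must serialize to (classpath, kwargs) so the triggerer
        # process can reconstruct them.
        return (
            "astronomer.providers.core.triggers.astro.AstroDeploymentTrigger",
            {
                "external_dag_id": self.external_dag_id,
                "dag_run_id": self.dag_run_id,
                "external_task_id": self.external_task_id,
                "astro_cloud_conn_id": self.astro_cloud_conn_id,
                "poke_interval": self.poke_interval,
            },
        )

    async def run(self) -> AsyncIterator[TriggerEvent]:
        hook = AstroHook(self.astro_cloud_conn_id)
        while True:
            if self.external_task_id is not None:
                ti = await hook.get_a_task_instance(
                    self.external_dag_id, self.dag_run_id, self.external_task_id
                )
                state = ti.get("state") if ti else None
            else:
                dr = await hook.get_a_dag_run(self.external_dag_id, self.dag_run_id)
                state = dr.get("state") if dr else None
            if state == "failed":
                yield TriggerEvent({"status": "failed"})
                return
            if state == "success":
                # execute_complete() only raises on "failed", so any other
                # status is treated as success.
                yield TriggerEvent({"status": "done"})
                return
            await asyncio.sleep(self.poke_interval)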