Skip to content

Commit

Permalink
Add ExternalDeploymentSensor (#1472)
Browse files Browse the repository at this point in the history
* Add ExternalDeploymentSensor

Introduce Astro Cloud connection type
Add ExternalDeploymentSensor, which uses Astro API to monitor Astro DAG status or task status.
The sensor takes external_dag_id as a required parameter and external_task_id as an optional one. If external_task_id is provided, it monitors the task instance of that task; otherwise it monitors the overall status of external_dag_id.
When the sensor starts executing, it fetches the latest DAG run of external_dag_id that is in the running or queued state. If one is found, the sensor waits in the trigger component for it to succeed or fail; if none is found, it returns immediately and is marked as success.

Assumption
When this sensor starts running it assumes that the dag it is monitoring is either in a running or queued state
  • Loading branch information
pankajastro authored Feb 20, 2024
1 parent 5961420 commit 3ca9989
Show file tree
Hide file tree
Showing 10 changed files with 855 additions and 1 deletion.
63 changes: 63 additions & 0 deletions astronomer/providers/core/example_dags/example_astro.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import time
from datetime import datetime

from airflow import DAG
from airflow.decorators import task
from airflow.operators.trigger_dagrun import TriggerDagRunOperator

from astronomer.providers.core.sensors.astro import ExternalDeploymentSensor

# DAG under test: two sensors watching the "wait" DAG declared further down.
# The first follows the DAG run as a whole, the second a single task in it.
with DAG(
    dag_id="example_astro_task",
    schedule=None,
    start_date=datetime(2022, 1, 1),
    catchup=False,
    tags=["example", "async", "core"],
):
    ExternalDeploymentSensor(
        external_dag_id="example_wait_to_test_example_astro_task",
        task_id="test1",
    )

    ExternalDeploymentSensor(
        external_dag_id="example_wait_to_test_example_astro_task",
        external_task_id="wait_for_2_min",
        task_id="test2",
    )

# Monitored DAG: a single task that simply sleeps, giving the sensors above
# something in the running state to observe.
with DAG(
    dag_id="example_wait_to_test_example_astro_task",
    schedule=None,
    start_date=datetime(2022, 1, 1),
    catchup=False,
    tags=["example", "async", "core"],
):

    @task
    def wait_for_2_min() -> None:
        """Sleep for two minutes."""
        time.sleep(120)

    wait_for_2_min()


# Driver DAG: fires the monitored DAG first, then the sensor DAG, without
# waiting for either triggered run to finish.
with DAG(
    dag_id="trigger_astro_test_and_example",
    schedule=None,
    start_date=datetime(2022, 1, 1),
    catchup=False,
    tags=["example", "async", "core"],
):
    run_wait_dag = TriggerDagRunOperator(
        trigger_dag_id="example_wait_to_test_example_astro_task",
        task_id="run_wait_dag",
        wait_for_completion=False,
    )

    run_astro_dag = TriggerDagRunOperator(
        trigger_dag_id="example_astro_task",
        task_id="run_astro_dag",
        wait_for_completion=False,
    )

    run_wait_dag.set_downstream(run_astro_dag)
Empty file.
156 changes: 156 additions & 0 deletions astronomer/providers/core/hooks/astro.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
from __future__ import annotations

import os
from typing import Any
from urllib.parse import quote

import requests
from aiohttp import ClientSession
from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook


class AstroHook(BaseHook):
    """
    Custom Apache Airflow Hook for interacting with Astro Cloud API.

    The hook reads the target Airflow base URL and an API token from an Airflow
    connection and queries the ``/api/v1`` DAG-run and task-instance endpoints,
    both synchronously (``requests``) and asynchronously (``aiohttp``).

    :param astro_cloud_conn_id: The connection ID to retrieve Astro Cloud credentials.
    """

    conn_name_attr = "astro_cloud_conn_id"
    default_conn_name = "astro_cloud_default"
    conn_type = "Astro Cloud"
    hook_name = "Astro Cloud"

    def __init__(self, astro_cloud_conn_id: str = "astro_cloud_conn_id"):
        # NOTE(review): this default differs from ``default_conn_name``
        # ("astro_cloud_default"). It is left untouched so existing callers that
        # rely on the current default keep working -- confirm which is intended.
        super().__init__()
        self.astro_cloud_conn_id = astro_cloud_conn_id

    @classmethod
    def get_ui_field_behaviour(cls) -> dict[str, Any]:
        """
        Returns UI field behavior customization for the Astro Cloud connection.

        This method defines hidden fields, relabeling, and placeholders for UI display.
        """
        return {
            "hidden_fields": ["login", "port", "schema", "extra"],
            "relabeling": {
                "password": "Astro Cloud API Token",
            },
            "placeholders": {
                "host": "https://clmkpsyfc010391acjie00t1l.astronomer.run/d5lc9c9x",
                "password": "Astro API JWT Token",
            },
        }

    def get_conn(self) -> tuple[str, str]:
        """
        Retrieves the Astro Cloud connection details.

        The base URL is taken from the connection host, falling back to the
        ``AIRFLOW__WEBSERVER__BASE_URL`` environment variable when the host is empty.

        :return: A ``(base_url, token)`` tuple.
        :raises AirflowException: If neither a host nor the fallback variable is set,
            or if the connection has no password (API token).
        """
        conn = BaseHook.get_connection(self.astro_cloud_conn_id)
        base_url = conn.host or os.environ.get("AIRFLOW__WEBSERVER__BASE_URL")
        if base_url is None:
            raise AirflowException(f"Airflow host is missing in connection {self.astro_cloud_conn_id}")
        token = conn.password
        if token is None:
            raise AirflowException(f"Astro API token is missing in connection {self.astro_cloud_conn_id}")
        return base_url, token

    @property
    def _headers(self) -> dict[str, str]:
        """Generates and returns headers for Astro Cloud API requests."""
        _, token = self.get_conn()
        headers = {"accept": "application/json", "Authorization": f"Bearer {token}"}
        return headers

    def _build_url(self, *segments: str) -> str:
        """
        Builds an absolute ``/api/v1`` URL from the given path segments.

        Every segment is percent-encoded, and a trailing slash on the configured
        base URL is dropped so the result never contains ``//api/v1``.

        :param segments: Path components that follow ``/api/v1/``.
        :raises AirflowException: Propagated from :meth:`get_conn`.
        """
        base_url, _ = self.get_conn()
        path = "/".join(quote(segment) for segment in segments)
        return f"{base_url.rstrip('/')}/api/v1/{path}"

    def get_dag_runs(self, external_dag_id: str) -> list[dict[str, str]]:
        """
        Retrieves information about running or queued DAG runs.

        Only the most recent matching run is requested (``limit`` is 1, ordered by
        descending ``execution_date``).

        :param external_dag_id: External ID of the DAG.
        :return: The ``dag_runs`` list from the API response.
        :raises requests.HTTPError: If the API responds with an error status.
        """
        url = self._build_url("dags", external_dag_id, "dagRuns")
        params: dict[str, int | str | list[str]] = {
            "limit": 1,
            "state": ["running", "queued"],
            "order_by": "-execution_date",
        }
        response = requests.get(url, headers=self._headers, params=params)
        response.raise_for_status()
        data: dict[str, list[dict[str, str]]] = response.json()
        return data["dag_runs"]

    def get_dag_run(self, external_dag_id: str, dag_run_id: str) -> dict[str, Any] | None:
        """
        Retrieves information about a specific DAG run.

        :param external_dag_id: External ID of the DAG.
        :param dag_run_id: ID of the DAG run.
        :return: The decoded JSON body of the response.
        :raises requests.HTTPError: If the API responds with an error status.
        """
        url = self._build_url("dags", external_dag_id, "dagRuns", dag_run_id)
        response = requests.get(url, headers=self._headers)
        response.raise_for_status()
        dr: dict[str, Any] = response.json()
        return dr

    async def get_a_dag_run(self, external_dag_id: str, dag_run_id: str) -> dict[str, Any] | None:
        """
        Retrieves information about a specific DAG run (asynchronous variant).

        :param external_dag_id: External ID of the DAG.
        :param dag_run_id: ID of the DAG run.
        :return: The decoded JSON body of the response.
        :raises aiohttp.ClientResponseError: If the API responds with an error status.
        """
        url = self._build_url("dags", external_dag_id, "dagRuns", dag_run_id)

        async with ClientSession(headers=self._headers) as session:
            async with session.get(url) as response:
                response.raise_for_status()
                dr: dict[str, Any] = await response.json()
                return dr

    def get_task_instance(
        self, external_dag_id: str, dag_run_id: str, external_task_id: str
    ) -> dict[str, Any] | None:
        """
        Retrieves information about a specific task instance within a DAG run.

        :param external_dag_id: External ID of the DAG.
        :param dag_run_id: ID of the DAG run.
        :param external_task_id: External ID of the task.
        :return: The decoded JSON body of the response.
        :raises requests.HTTPError: If the API responds with an error status.
        """
        url = self._build_url(
            "dags", external_dag_id, "dagRuns", dag_run_id, "taskInstances", external_task_id
        )
        response = requests.get(url, headers=self._headers)
        response.raise_for_status()
        ti: dict[str, Any] = response.json()
        return ti

    async def get_a_task_instance(
        self, external_dag_id: str, dag_run_id: str, external_task_id: str
    ) -> dict[str, Any] | None:
        """
        Retrieves information about a specific task instance within a DAG run
        (asynchronous variant).

        :param external_dag_id: External ID of the DAG.
        :param dag_run_id: ID of the DAG run.
        :param external_task_id: External ID of the task.
        :return: The decoded JSON body of the response.
        :raises aiohttp.ClientResponseError: If the API responds with an error status.
        """
        url = self._build_url(
            "dags", external_dag_id, "dagRuns", dag_run_id, "taskInstances", external_task_id
        )

        async with ClientSession(headers=self._headers) as session:
            async with session.get(url) as response:
                response.raise_for_status()
                ti: dict[str, Any] = await response.json()
                return ti
100 changes: 100 additions & 0 deletions astronomer/providers/core/sensors/astro.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
from __future__ import annotations

import datetime
from typing import Any, cast

from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.sensors.base import BaseSensorOperator, PokeReturnValue

from astronomer.providers.core.hooks.astro import AstroHook
from astronomer.providers.core.triggers.astro import AstroDeploymentTrigger
from astronomer.providers.utils.typing_compat import Context


class ExternalDeploymentSensor(BaseSensorOperator):
    """
    Sensor that watches a DAG, or one task of a DAG, through the Astro Cloud API.

    :param external_dag_id: ID of the DAG to watch.
    :param astro_cloud_conn_id: Connection ID holding the Astro Cloud credentials.
        Defaults to "astro_cloud_default".
    :param external_task_id: ID of the task to watch. When None, the DAG run
        itself is watched.
    :param kwargs: Extra keyword arguments forwarded to BaseSensorOperator.
    """

    def __init__(
        self,
        external_dag_id: str,
        astro_cloud_conn_id: str = "astro_cloud_default",
        external_task_id: str | None = None,
        **kwargs: Any,
    ):
        super().__init__(**kwargs)
        self.external_dag_id = external_dag_id
        self.external_task_id = external_task_id
        self.astro_cloud_conn_id = astro_cloud_conn_id
        # Filled in by poke() and handed to the trigger when execute() defers.
        self._dag_run_id: str = ""

    def poke(self, context: Context) -> bool | PokeReturnValue:
        """
        Look up the state of the watched DAG run or task instance.

        Returns True when the hook reports no running or queued DAG run, or when
        the watched DAG run / task instance is in the "success" state; returns
        False in every other case.

        :param context: The task execution context.
        """
        astro = AstroHook(self.astro_cloud_conn_id)
        active_runs: list[dict[str, Any]] = astro.get_dag_runs(self.external_dag_id)
        if not active_runs:
            self.log.info("No DAG runs found for DAG %s", self.external_dag_id)
            return True

        latest_run = active_runs[0]
        self._dag_run_id = cast(str, latest_run["dag_run_id"])

        # No task given: judge by the state of the DAG run itself.
        if self.external_task_id is None:
            return latest_run.get("state") == "success"

        ti = astro.get_task_instance(self.external_dag_id, self._dag_run_id, self.external_task_id)
        observed_state = ti.get("state") if ti else None
        return observed_state == "success"

    def execute(self, context: Context) -> Any:
        """
        Run the sensor once and defer to the trigger if the target is not done.

        When the initial poke is not satisfied, execution is deferred to an
        AstroDeploymentTrigger and resumes in ``execute_complete``.

        :param context: The task execution context.
        """
        if self.poke(context):
            return None

        trigger = AstroDeploymentTrigger(
            astro_cloud_conn_id=self.astro_cloud_conn_id,
            external_task_id=self.external_task_id,
            external_dag_id=self.external_dag_id,
            poke_interval=self.poke_interval,
            dag_run_id=self._dag_run_id,
        )
        self.defer(
            timeout=datetime.timedelta(seconds=self.timeout),
            trigger=trigger,
            method_name="execute_complete",
        )

    def execute_complete(self, context: Context, event: dict[str, str]) -> None:
        """
        Resume after the trigger fires and turn a failure event into an exception.

        An event whose "status" is "failed" raises AirflowSkipException when
        ``soft_fail`` is set and AirflowException otherwise; any other event
        completes the task normally.

        :param context: The task execution context.
        :param event: The event dictionary emitted by the trigger.
        """
        if event.get("status") != "failed":
            return
        if self.soft_fail:
            raise AirflowSkipException("Upstream job failed. Skipping the task.")
        raise AirflowException("Upstream job failed.")
Loading

0 comments on commit 3ca9989

Please sign in to comment.