Skip to content

Commit

Permalink
Deprecate AzureDataFactoryPipelineRunStatusSensorAsync
Browse files Browse the repository at this point in the history
Deprecate AzureDataFactoryPipelineRunStatusSensorAsync and proxy it
to its Airflow OSS provider's counterpart

related: #1412
  • Loading branch information
pankajkoti committed Jan 22, 2024
1 parent e7bf96c commit 98edef8
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 106 deletions.
14 changes: 11 additions & 3 deletions astronomer/providers/microsoft/azure/hooks/data_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from __future__ import annotations

import inspect
import warnings
from functools import wraps
from typing import Any, TypeVar, Union, cast

Expand Down Expand Up @@ -68,13 +69,20 @@ async def bind_argument(arg: Any, default_key: str) -> None:

class AzureDataFactoryHookAsync(AzureDataFactoryHook):
"""
An Async Hook connects to Azure DataFactory to perform pipeline operations.
:param azure_data_factory_conn_id: The :ref:`Azure Data Factory connection id<howto/connection:adf>`.
This class is deprecated and will be removed in 2.0.0.
Use :class: `~airflow.providers.microsoft.azure.hooks.data_factory.AzureDataFactoryHook` instead.
"""

def __init__(self, azure_data_factory_conn_id: str):
"""Initialize the hook instance."""
warnings.warn(
(
"This class is deprecated and will be removed in 2.0.0."
"Use :class: `~airflow.providers.microsoft.azure.hooks.data_factory.AzureDataFactoryHook` instead."
),
DeprecationWarning,
stacklevel=2,
)
self._async_conn: DataFactoryManagementClient | None = None
self.conn_id = azure_data_factory_conn_id
super().__init__(azure_data_factory_conn_id=azure_data_factory_conn_id)
Expand Down
55 changes: 8 additions & 47 deletions astronomer/providers/microsoft/azure/sensors/data_factory.py
Original file line number Diff line number Diff line change
@@ -1,68 +1,29 @@
import warnings
from datetime import timedelta
from typing import Any, Dict
from typing import Any

from airflow.providers.microsoft.azure.sensors.data_factory import (
AzureDataFactoryPipelineRunStatusSensor,
)

from astronomer.providers.microsoft.azure.triggers.data_factory import (
ADFPipelineRunStatusSensorTrigger,
)
from astronomer.providers.utils.sensor_util import poke, raise_error_or_skip_exception
from astronomer.providers.utils.typing_compat import Context
from airflow.providers.microsoft.azure.sensors.data_factory import AzureDataFactoryPipelineRunStatusSensor


class AzureDataFactoryPipelineRunStatusSensorAsync(AzureDataFactoryPipelineRunStatusSensor):
"""
Checks the status of a pipeline run.
:param azure_data_factory_conn_id: The connection identifier for connecting to Azure Data Factory.
:param run_id: The pipeline run identifier.
:param resource_group_name: The resource group name.
:param factory_name: The data factory name.
:param poll_interval: polling period in seconds to check for the status
This class is deprecated.
Use :class: `~airflow.providers.microsoft.azure.sensors.data_factory.AzureDataFactoryPipelineRunStatusSensor`
instead and set `deferrable` param to `True` instead.
"""

def __init__(
self,
*,
*args: Any,
poll_interval: float = 5,
**kwargs: Any,
):
# TODO: Remove once deprecated
if poll_interval:
self.poke_interval = poll_interval
kwargs["poke_interval"] = poll_interval
warnings.warn(
"Argument `poll_interval` is deprecated and will be removed "
"in a future release. Please use `poke_interval` instead.",
DeprecationWarning,
stacklevel=2,
)
super().__init__(**kwargs)

def execute(self, context: Context) -> None:
"""Defers trigger class to poll for state of the job run until it reaches a failure state or success state"""
if not poke(self, context):
self.defer(
timeout=timedelta(seconds=self.timeout),
trigger=ADFPipelineRunStatusSensorTrigger(
run_id=self.run_id,
azure_data_factory_conn_id=self.azure_data_factory_conn_id,
resource_group_name=self.resource_group_name,
factory_name=self.factory_name,
poke_interval=self.poke_interval,
),
method_name="execute_complete",
)

def execute_complete(self, context: Context, event: Dict[str, str]) -> None:
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
successful.
"""
if event:
if event["status"] == "error":
raise_error_or_skip_exception(self.soft_fail, event["message"])
self.log.info(event["message"])
super().__init__(*args, deferrable=True, **kwargs)
10 changes: 2 additions & 8 deletions astronomer/providers/microsoft/azure/triggers/data_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,8 @@

class ADFPipelineRunStatusSensorTrigger(BaseTrigger):
"""
ADFPipelineRunStatusSensorTrigger is fired as deferred class with params to run the task in trigger worker, when
ADF Pipeline is running
:param run_id: The pipeline run identifier.
:param azure_data_factory_conn_id: The connection identifier for connecting to Azure Data Factory.
:param poke_interval: polling period in seconds to check for the status
:param resource_group_name: The resource group name.
:param factory_name: The data factory name.
This class is deprecated and will be removed in 2.0.0.
Use :class: `~airflow.providers.microsoft.azure.triggers.data_factory.ADFPipelineRunStatusSensorTrigger` instead.
"""

def __init__(
Expand Down
58 changes: 10 additions & 48 deletions tests/microsoft/azure/sensors/test_data_factory.py
Original file line number Diff line number Diff line change
@@ -1,62 +1,24 @@
from unittest import mock

import pytest
from airflow.exceptions import AirflowException, TaskDeferred
from airflow.providers.microsoft.azure.sensors.data_factory import AzureDataFactoryPipelineRunStatusSensor

from astronomer.providers.microsoft.azure.sensors.data_factory import (
AzureDataFactoryPipelineRunStatusSensorAsync,
)
from astronomer.providers.microsoft.azure.triggers.data_factory import (
ADFPipelineRunStatusSensorTrigger,
)
from tests.utils.airflow_util import create_context

MODULE = "astronomer.providers.microsoft.azure.sensors.data_factory"


class TestAzureDataFactoryPipelineRunStatusSensorAsync:
RUN_ID = "7f8c6c72-c093-11ec-a83d-0242ac120007"
SENSOR = AzureDataFactoryPipelineRunStatusSensorAsync(
task_id="pipeline_run_sensor_async",
run_id=RUN_ID,
factory_name="factory_name",
resource_group_name="resource_group_name",
)

@mock.patch(f"{MODULE}.AzureDataFactoryPipelineRunStatusSensorAsync.defer")
@mock.patch(f"{MODULE}.AzureDataFactoryPipelineRunStatusSensorAsync.poke", return_value=True)
def test_adf_pipeline_status_sensor_async_finish_before_deferred(
self,
mock_poke,
mock_defer,
):
"""Assert task is not deferred when it receives a finish status before deferring"""
self.SENSOR.execute(create_context(self.SENSOR))
assert not mock_defer.called

@mock.patch(f"{MODULE}.AzureDataFactoryPipelineRunStatusSensorAsync.poke", return_value=False)
def test_adf_pipeline_status_sensor_async(self, mock_poke):
"""Assert execute method defer for Azure Data factory pipeline run status sensor"""

with pytest.raises(TaskDeferred) as exc:
self.SENSOR.execute(create_context(self.SENSOR))
assert isinstance(
exc.value.trigger, ADFPipelineRunStatusSensorTrigger
), "Trigger is not a ADFPipelineRunStatusSensorTrigger"

def test_adf_pipeline_status_sensor_execute_complete_success(self):
"""Assert execute_complete log success message when trigger fire with target status"""

msg = f"Pipeline run {self.RUN_ID} has been succeeded."
with mock.patch.object(self.SENSOR.log, "info") as mock_log_info:
self.SENSOR.execute_complete(context={}, event={"status": "success", "message": msg})
mock_log_info.assert_called_with(msg)

def test_adf_pipeline_status_sensor_execute_complete_failure(self):
"""Assert execute_complete method fail"""
def test_init(self):
task = AzureDataFactoryPipelineRunStatusSensorAsync(
task_id="pipeline_run_sensor_async",
run_id=self.RUN_ID,
factory_name="factory_name",
resource_group_name="resource_group_name",
)

with pytest.raises(AirflowException):
self.SENSOR.execute_complete(context={}, event={"status": "error", "message": ""})
assert isinstance(task, AzureDataFactoryPipelineRunStatusSensor)
assert task.deferrable is True

def test_poll_interval_deprecation_warning(self):
"""Test DeprecationWarning for AzureDataFactoryPipelineRunStatusSensorAsync by setting param poll_interval"""
Expand Down

0 comments on commit 98edef8

Please sign in to comment.