Skip to content

Commit

Permalink
Deprecate AzureDataFactoryRunPipelineOperatorAsync
Browse files Browse the repository at this point in the history
Deprecate AzureDataFactoryRunPipelineOperatorAsync and proxy it to
its Airflow OSS provider's counterpart

related: #1412
  • Loading branch information
pankajkoti committed Jan 22, 2024
1 parent 4f02e6a commit 5304103
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 205 deletions.
107 changes: 14 additions & 93 deletions astronomer/providers/microsoft/azure/operators/data_factory.py
Original file line number Diff line number Diff line change
@@ -1,104 +1,25 @@
import time
from typing import Dict
import warnings

from airflow.exceptions import AirflowException
from airflow.providers.microsoft.azure.hooks.data_factory import (
AzureDataFactoryHook,
AzureDataFactoryPipelineRunException,
AzureDataFactoryPipelineRunStatus,
)
from airflow.providers.microsoft.azure.operators.data_factory import (
AzureDataFactoryRunPipelineOperator,
)

from astronomer.providers.microsoft.azure.triggers.data_factory import (
AzureDataFactoryTrigger,
)
from astronomer.providers.utils.typing_compat import Context


class AzureDataFactoryRunPipelineOperatorAsync(AzureDataFactoryRunPipelineOperator):
"""
Executes a data factory pipeline asynchronously.
:param azure_data_factory_conn_id: The connection identifier for connecting to Azure Data Factory.
:param pipeline_name: The name of the pipeline to execute.
:param wait_for_termination: Flag to wait on a pipeline run's termination. By default, this feature is
enabled but could be disabled to perform an asynchronous wait for a long-running pipeline execution
using the ``AzureDataFactoryPipelineRunSensor``.
:param resource_group_name: The resource group name. If a value is not passed in to the operator, the
``AzureDataFactoryHook`` will attempt to use the resource group name provided in the corresponding
connection.
:param factory_name: The data factory name. If a value is not passed in to the operator, the
``AzureDataFactoryHook`` will attempt to use the factory name name provided in the corresponding
connection.
:param reference_pipeline_run_id: The pipeline run identifier. If this run ID is specified the parameters
of the specified run will be used to create a new run.
:param is_recovery: Recovery mode flag. If recovery mode is set to `True`, the specified referenced
pipeline run and the new run will be grouped under the same ``groupId``.
:param start_activity_name: In recovery mode, the rerun will start from this activity. If not specified,
all activities will run.
:param start_from_failure: In recovery mode, if set to true, the rerun will start from failed activities.
The property will be used only if ``start_activity_name`` is not specified.
:param parameters: Parameters of the pipeline run. These parameters are referenced in a pipeline via
``@pipeline().parameters.parameterName`` and will be used only if the ``reference_pipeline_run_id`` is
not specified.
:param timeout: Time in seconds to wait for a pipeline to reach a terminal status for non-asynchronous
waits. Used only if ``wait_for_termination`` is True.
:param check_interval: Time in seconds to check on a pipeline run's status for non-asynchronous waits.
Used only if ``wait_for_termination``
This class is deprecated.
Use :class: `~airflow.providers.microsoft.azure.operators.data_factory.AzureDataFactoryRunPipelineOperator` instead
and set `deferrable` param to `True` instead.
"""

def execute(self, context: Context) -> None:
"""Submits a job which generates a run_id and gets deferred"""
hook = AzureDataFactoryHook(azure_data_factory_conn_id=self.azure_data_factory_conn_id)
response = hook.run_pipeline(
pipeline_name=self.pipeline_name,
resource_group_name=self.resource_group_name,
factory_name=self.factory_name,
reference_pipeline_run_id=self.reference_pipeline_run_id,
is_recovery=self.is_recovery,
start_activity_name=self.start_activity_name,
start_from_failure=self.start_from_failure,
parameters=self.parameters,
def __init__(self, *args, **kwargs): # type: ignore[no-untyped-def]
warnings.warn(
(
"This class is deprecated. "
"Use `airflow.providers.microsoft.azure.operators.data_factory.AzureDataFactoryRunPipelineOperator` "
"and set `deferrable` param to `True` instead."
),
DeprecationWarning,
stacklevel=2,
)
run_id = vars(response)["run_id"]
context["ti"].xcom_push(key="run_id", value=run_id)
end_time = time.time() + self.timeout

pipeline_run_status = hook.get_pipeline_run_status(
run_id=run_id,
resource_group_name=self.resource_group_name,
factory_name=self.factory_name,
)
if pipeline_run_status not in AzureDataFactoryPipelineRunStatus.TERMINAL_STATUSES:
self.defer(
timeout=self.execution_timeout,
trigger=AzureDataFactoryTrigger(
azure_data_factory_conn_id=self.azure_data_factory_conn_id,
run_id=run_id,
wait_for_termination=self.wait_for_termination,
resource_group_name=self.resource_group_name,
factory_name=self.factory_name,
check_interval=self.check_interval,
end_time=end_time,
),
method_name="execute_complete",
)
elif pipeline_run_status == AzureDataFactoryPipelineRunStatus.SUCCEEDED:
self.log.info("Pipeline run %s has completed successfully.", run_id)
elif pipeline_run_status in AzureDataFactoryPipelineRunStatus.FAILURE_STATES:
raise AzureDataFactoryPipelineRunException(
f"Pipeline run {run_id} has failed or has been cancelled."
)

def execute_complete(self, context: Context, event: Dict[str, str]) -> None:
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
successful.
"""
if event:
if event["status"] == "error":
raise AirflowException(event["message"])
self.log.info(event["message"])
super().__init__(*args, deferrable=True, **kwargs)
21 changes: 11 additions & 10 deletions astronomer/providers/microsoft/azure/triggers/data_factory.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import time
import warnings
from typing import Any, AsyncIterator, Dict, List, Tuple

from airflow.providers.microsoft.azure.hooks.data_factory import (
Expand Down Expand Up @@ -73,16 +74,8 @@ async def run(self) -> AsyncIterator["TriggerEvent"]:

class AzureDataFactoryTrigger(BaseTrigger):
"""
AzureDataFactoryTrigger is triggered when Azure data factory pipeline job succeeded or failed.
When wait_for_termination is set to False it triggered immediately with success status
:param run_id: Run id of a Azure data pipeline run job.
:param azure_data_factory_conn_id: The connection identifier for connecting to Azure Data Factory.
:param end_time: Time in seconds when triggers will timeout.
:param resource_group_name: The resource group name.
:param factory_name: The data factory name.
:param wait_for_termination: Flag to wait on a pipeline run's termination.
:param check_interval: Time in seconds to check on a pipeline run's status.
This class is deprecated and will be removed in 2.0.0.
Use :class: `~airflow.providers.microsoft.azure.triggers.data_factory.AzureDataFactoryTrigger` instead.
"""

QUEUED = "Queued"
Expand All @@ -107,6 +100,14 @@ def __init__(
wait_for_termination: bool = True,
check_interval: int = 60,
):
warnings.warn(
(
"This class is deprecated and will be removed in 2.0.0."
"Use :class: `~airflow.providers.microsoft.azure.triggers.data_factory.AzureDataFactoryTrigger` instead"
),
DeprecationWarning,
stacklevel=2,
)
super().__init__()
self.azure_data_factory_conn_id = azure_data_factory_conn_id
self.check_interval = check_interval
Expand Down
114 changes: 12 additions & 102 deletions tests/microsoft/azure/operators/test_data_factory.py
Original file line number Diff line number Diff line change
@@ -1,109 +1,19 @@
from unittest import mock

import pytest
from airflow.exceptions import AirflowException, TaskDeferred
from airflow.providers.microsoft.azure.hooks.data_factory import (
AzureDataFactoryPipelineRunException,
AzureDataFactoryPipelineRunStatus,
)
from airflow.providers.microsoft.azure.operators.data_factory import AzureDataFactoryRunPipelineOperator

from astronomer.providers.microsoft.azure.operators.data_factory import (
AzureDataFactoryRunPipelineOperatorAsync,
)
from astronomer.providers.microsoft.azure.triggers.data_factory import (
AzureDataFactoryTrigger,
)
from tests.utils.airflow_util import create_context

AZ_PIPELINE_RUN_ID = "7f8c6c72-c093-11ec-a83d-0242ac120007"

MODULE = "astronomer.providers.microsoft.azure.operators.data_factory"


class TestAzureDataFactoryRunPipelineOperatorAsync:
OPERATOR = AzureDataFactoryRunPipelineOperatorAsync(
task_id="run_pipeline",
pipeline_name="pipeline",
parameters={"myParam": "value"},
factory_name="factory_name",
resource_group_name="resource_group",
)

@mock.patch(f"{MODULE}.AzureDataFactoryRunPipelineOperatorAsync.defer")
@mock.patch(
"airflow.providers.microsoft.azure.hooks.data_factory.AzureDataFactoryHook.get_pipeline_run_status",
return_value=AzureDataFactoryPipelineRunStatus.SUCCEEDED,
)
@mock.patch("airflow.providers.microsoft.azure.hooks.data_factory.AzureDataFactoryHook.run_pipeline")
def test_azure_data_factory_run_pipeline_operator_async_succeeded_before_deferred(
self, mock_run_pipeline, mock_get_status, mock_defer
):
class CreateRunResponse:
pass

CreateRunResponse.run_id = AZ_PIPELINE_RUN_ID
mock_run_pipeline.return_value = CreateRunResponse

self.OPERATOR.execute(context=create_context(self.OPERATOR))
assert not mock_defer.called

@pytest.mark.parametrize("status", AzureDataFactoryPipelineRunStatus.FAILURE_STATES)
@mock.patch(f"{MODULE}.AzureDataFactoryRunPipelineOperatorAsync.defer")
@mock.patch(
"airflow.providers.microsoft.azure.hooks.data_factory.AzureDataFactoryHook.get_pipeline_run_status",
)
@mock.patch("airflow.providers.microsoft.azure.hooks.data_factory.AzureDataFactoryHook.run_pipeline")
def test_azure_data_factory_run_pipeline_operator_async_error_before_deferred(
self, mock_run_pipeline, mock_get_status, mock_defer, status
):
mock_get_status.return_value = status

class CreateRunResponse:
pass

CreateRunResponse.run_id = AZ_PIPELINE_RUN_ID
mock_run_pipeline.return_value = CreateRunResponse

with pytest.raises(AzureDataFactoryPipelineRunException):
self.OPERATOR.execute(context=create_context(self.OPERATOR))
assert not mock_defer.called

@pytest.mark.parametrize("status", AzureDataFactoryPipelineRunStatus.INTERMEDIATE_STATES)
@mock.patch(
"airflow.providers.microsoft.azure.hooks.data_factory.AzureDataFactoryHook.get_pipeline_run_status",
)
@mock.patch("airflow.providers.microsoft.azure.hooks.data_factory.AzureDataFactoryHook.run_pipeline")
def test_azure_data_factory_run_pipeline_operator_async(self, mock_run_pipeline, mock_get_status, status):
"""Assert that AzureDataFactoryRunPipelineOperatorAsync deferred"""

class CreateRunResponse:
pass

CreateRunResponse.run_id = AZ_PIPELINE_RUN_ID
mock_run_pipeline.return_value = CreateRunResponse

with pytest.raises(TaskDeferred) as exc:
self.OPERATOR.execute(context=create_context(self.OPERATOR))

assert isinstance(
exc.value.trigger, AzureDataFactoryTrigger
), "Trigger is not a AzureDataFactoryTrigger"

def test_azure_data_factory_run_pipeline_operator_async_execute_complete_success(self):
"""Assert that execute_complete log success message"""

with mock.patch.object(self.OPERATOR.log, "info") as mock_log_info:
self.OPERATOR.execute_complete(
context=create_context(self.OPERATOR),
event={"status": "success", "message": "success", "run_id": AZ_PIPELINE_RUN_ID},
)
mock_log_info.assert_called_with("success")

def test_azure_data_factory_run_pipeline_operator_async_execute_complete_fail(self):
"""Assert that execute_complete raise exception on error"""

with pytest.raises(AirflowException):
self.OPERATOR.execute_complete(
context=create_context(self.OPERATOR),
event={"status": "error", "message": "error", "run_id": AZ_PIPELINE_RUN_ID},
)
def test_init(self):
task = AzureDataFactoryRunPipelineOperatorAsync(
task_id="run_pipeline",
pipeline_name="pipeline",
parameters={"myParam": "value"},
factory_name="factory_name",
resource_group_name="resource_group",
)

assert isinstance(task, AzureDataFactoryRunPipelineOperator)
assert task.deferrable is True

0 comments on commit 5304103

Please sign in to comment.