From 5304103b61f9c3a94163db11a29c43d422f7b075 Mon Sep 17 00:00:00 2001 From: Pankaj Koti Date: Thu, 18 Jan 2024 16:08:18 +0530 Subject: [PATCH] Deprecate AzureDataFactoryRunPipelineOperatorAsync Deprecate AzureDataFactoryRunPipelineOperatorAsync and proxy it to its Airflow OSS provider's counterpart related: #1412 --- .../microsoft/azure/operators/data_factory.py | 107 +++------------- .../microsoft/azure/triggers/data_factory.py | 21 ++-- .../azure/operators/test_data_factory.py | 114 ++---------------- 3 files changed, 37 insertions(+), 205 deletions(-) diff --git a/astronomer/providers/microsoft/azure/operators/data_factory.py b/astronomer/providers/microsoft/azure/operators/data_factory.py index 85bcdf2e1..811948e6c 100644 --- a/astronomer/providers/microsoft/azure/operators/data_factory.py +++ b/astronomer/providers/microsoft/azure/operators/data_factory.py @@ -1,104 +1,25 @@ -import time -from typing import Dict +import warnings -from airflow.exceptions import AirflowException -from airflow.providers.microsoft.azure.hooks.data_factory import ( - AzureDataFactoryHook, - AzureDataFactoryPipelineRunException, - AzureDataFactoryPipelineRunStatus, -) from airflow.providers.microsoft.azure.operators.data_factory import ( AzureDataFactoryRunPipelineOperator, ) -from astronomer.providers.microsoft.azure.triggers.data_factory import ( - AzureDataFactoryTrigger, -) -from astronomer.providers.utils.typing_compat import Context - class AzureDataFactoryRunPipelineOperatorAsync(AzureDataFactoryRunPipelineOperator): """ - Executes a data factory pipeline asynchronously. - - :param azure_data_factory_conn_id: The connection identifier for connecting to Azure Data Factory. - :param pipeline_name: The name of the pipeline to execute. - :param wait_for_termination: Flag to wait on a pipeline run's termination. 
By default, this feature is - enabled but could be disabled to perform an asynchronous wait for a long-running pipeline execution - using the ``AzureDataFactoryPipelineRunSensor``. - :param resource_group_name: The resource group name. If a value is not passed in to the operator, the - ``AzureDataFactoryHook`` will attempt to use the resource group name provided in the corresponding - connection. - :param factory_name: The data factory name. If a value is not passed in to the operator, the - ``AzureDataFactoryHook`` will attempt to use the factory name name provided in the corresponding - connection. - :param reference_pipeline_run_id: The pipeline run identifier. If this run ID is specified the parameters - of the specified run will be used to create a new run. - :param is_recovery: Recovery mode flag. If recovery mode is set to `True`, the specified referenced - pipeline run and the new run will be grouped under the same ``groupId``. - :param start_activity_name: In recovery mode, the rerun will start from this activity. If not specified, - all activities will run. - :param start_from_failure: In recovery mode, if set to true, the rerun will start from failed activities. - The property will be used only if ``start_activity_name`` is not specified. - :param parameters: Parameters of the pipeline run. These parameters are referenced in a pipeline via - ``@pipeline().parameters.parameterName`` and will be used only if the ``reference_pipeline_run_id`` is - not specified. - :param timeout: Time in seconds to wait for a pipeline to reach a terminal status for non-asynchronous - waits. Used only if ``wait_for_termination`` is True. - :param check_interval: Time in seconds to check on a pipeline run's status for non-asynchronous waits. - Used only if ``wait_for_termination`` + This class is deprecated. + Use :class:`~airflow.providers.microsoft.azure.operators.data_factory.AzureDataFactoryRunPipelineOperator` + and set `deferrable` param to `True` instead.
""" - def execute(self, context: Context) -> None: - """Submits a job which generates a run_id and gets deferred""" - hook = AzureDataFactoryHook(azure_data_factory_conn_id=self.azure_data_factory_conn_id) - response = hook.run_pipeline( - pipeline_name=self.pipeline_name, - resource_group_name=self.resource_group_name, - factory_name=self.factory_name, - reference_pipeline_run_id=self.reference_pipeline_run_id, - is_recovery=self.is_recovery, - start_activity_name=self.start_activity_name, - start_from_failure=self.start_from_failure, - parameters=self.parameters, + def __init__(self, *args, **kwargs): # type: ignore[no-untyped-def] + warnings.warn( + ( + "This class is deprecated. " + "Use `airflow.providers.microsoft.azure.operators.data_factory.AzureDataFactoryRunPipelineOperator` " + "and set `deferrable` param to `True` instead." + ), + DeprecationWarning, + stacklevel=2, ) - run_id = vars(response)["run_id"] - context["ti"].xcom_push(key="run_id", value=run_id) - end_time = time.time() + self.timeout - - pipeline_run_status = hook.get_pipeline_run_status( - run_id=run_id, - resource_group_name=self.resource_group_name, - factory_name=self.factory_name, - ) - if pipeline_run_status not in AzureDataFactoryPipelineRunStatus.TERMINAL_STATUSES: - self.defer( - timeout=self.execution_timeout, - trigger=AzureDataFactoryTrigger( - azure_data_factory_conn_id=self.azure_data_factory_conn_id, - run_id=run_id, - wait_for_termination=self.wait_for_termination, - resource_group_name=self.resource_group_name, - factory_name=self.factory_name, - check_interval=self.check_interval, - end_time=end_time, - ), - method_name="execute_complete", - ) - elif pipeline_run_status == AzureDataFactoryPipelineRunStatus.SUCCEEDED: - self.log.info("Pipeline run %s has completed successfully.", run_id) - elif pipeline_run_status in AzureDataFactoryPipelineRunStatus.FAILURE_STATES: - raise AzureDataFactoryPipelineRunException( - f"Pipeline run {run_id} has failed or has been cancelled." 
- ) - - def execute_complete(self, context: Context, event: Dict[str, str]) -> None: - """ - Callback for when the trigger fires - returns immediately. - Relies on trigger to throw an exception, otherwise it assumes execution was - successful. - """ - if event: - if event["status"] == "error": - raise AirflowException(event["message"]) - self.log.info(event["message"]) + super().__init__(*args, deferrable=True, **kwargs) diff --git a/astronomer/providers/microsoft/azure/triggers/data_factory.py b/astronomer/providers/microsoft/azure/triggers/data_factory.py index 29b727bf3..d171bab5b 100644 --- a/astronomer/providers/microsoft/azure/triggers/data_factory.py +++ b/astronomer/providers/microsoft/azure/triggers/data_factory.py @@ -1,5 +1,6 @@ import asyncio import time +import warnings from typing import Any, AsyncIterator, Dict, List, Tuple from airflow.providers.microsoft.azure.hooks.data_factory import ( @@ -73,16 +74,8 @@ async def run(self) -> AsyncIterator["TriggerEvent"]: class AzureDataFactoryTrigger(BaseTrigger): """ - AzureDataFactoryTrigger is triggered when Azure data factory pipeline job succeeded or failed. - When wait_for_termination is set to False it triggered immediately with success status - - :param run_id: Run id of a Azure data pipeline run job. - :param azure_data_factory_conn_id: The connection identifier for connecting to Azure Data Factory. - :param end_time: Time in seconds when triggers will timeout. - :param resource_group_name: The resource group name. - :param factory_name: The data factory name. - :param wait_for_termination: Flag to wait on a pipeline run's termination. - :param check_interval: Time in seconds to check on a pipeline run's status. + This class is deprecated and will be removed in 2.0.0. + Use :class:`~airflow.providers.microsoft.azure.triggers.data_factory.AzureDataFactoryTrigger` instead.
""" QUEUED = "Queued" @@ -107,6 +100,14 @@ def __init__( wait_for_termination: bool = True, check_interval: int = 60, ): + warnings.warn( + ( + "This class is deprecated and will be removed in 2.0.0. " + "Use :class: `~airflow.providers.microsoft.azure.triggers.data_factory.AzureDataFactoryTrigger` instead" + ), + DeprecationWarning, + stacklevel=2, + ) super().__init__() self.azure_data_factory_conn_id = azure_data_factory_conn_id self.check_interval = check_interval diff --git a/tests/microsoft/azure/operators/test_data_factory.py b/tests/microsoft/azure/operators/test_data_factory.py index 992ff2d62..cea6bbf05 100644 --- a/tests/microsoft/azure/operators/test_data_factory.py +++ b/tests/microsoft/azure/operators/test_data_factory.py @@ -1,109 +1,19 @@ -from unittest import mock - -import pytest -from airflow.exceptions import AirflowException, TaskDeferred -from airflow.providers.microsoft.azure.hooks.data_factory import ( - AzureDataFactoryPipelineRunException, - AzureDataFactoryPipelineRunStatus, -) +from airflow.providers.microsoft.azure.operators.data_factory import AzureDataFactoryRunPipelineOperator from astronomer.providers.microsoft.azure.operators.data_factory import ( AzureDataFactoryRunPipelineOperatorAsync, ) -from astronomer.providers.microsoft.azure.triggers.data_factory import ( - AzureDataFactoryTrigger, -) -from tests.utils.airflow_util import create_context - -AZ_PIPELINE_RUN_ID = "7f8c6c72-c093-11ec-a83d-0242ac120007" - -MODULE = "astronomer.providers.microsoft.azure.operators.data_factory" class TestAzureDataFactoryRunPipelineOperatorAsync: - OPERATOR = AzureDataFactoryRunPipelineOperatorAsync( - task_id="run_pipeline", - pipeline_name="pipeline", - parameters={"myParam": "value"}, - factory_name="factory_name", - resource_group_name="resource_group", - ) - - @mock.patch(f"{MODULE}.AzureDataFactoryRunPipelineOperatorAsync.defer") - @mock.patch( -
"airflow.providers.microsoft.azure.hooks.data_factory.AzureDataFactoryHook.get_pipeline_run_status", - return_value=AzureDataFactoryPipelineRunStatus.SUCCEEDED, - ) - @mock.patch("airflow.providers.microsoft.azure.hooks.data_factory.AzureDataFactoryHook.run_pipeline") - def test_azure_data_factory_run_pipeline_operator_async_succeeded_before_deferred( - self, mock_run_pipeline, mock_get_status, mock_defer - ): - class CreateRunResponse: - pass - - CreateRunResponse.run_id = AZ_PIPELINE_RUN_ID - mock_run_pipeline.return_value = CreateRunResponse - - self.OPERATOR.execute(context=create_context(self.OPERATOR)) - assert not mock_defer.called - - @pytest.mark.parametrize("status", AzureDataFactoryPipelineRunStatus.FAILURE_STATES) - @mock.patch(f"{MODULE}.AzureDataFactoryRunPipelineOperatorAsync.defer") - @mock.patch( - "airflow.providers.microsoft.azure.hooks.data_factory.AzureDataFactoryHook.get_pipeline_run_status", - ) - @mock.patch("airflow.providers.microsoft.azure.hooks.data_factory.AzureDataFactoryHook.run_pipeline") - def test_azure_data_factory_run_pipeline_operator_async_error_before_deferred( - self, mock_run_pipeline, mock_get_status, mock_defer, status - ): - mock_get_status.return_value = status - - class CreateRunResponse: - pass - - CreateRunResponse.run_id = AZ_PIPELINE_RUN_ID - mock_run_pipeline.return_value = CreateRunResponse - - with pytest.raises(AzureDataFactoryPipelineRunException): - self.OPERATOR.execute(context=create_context(self.OPERATOR)) - assert not mock_defer.called - - @pytest.mark.parametrize("status", AzureDataFactoryPipelineRunStatus.INTERMEDIATE_STATES) - @mock.patch( - "airflow.providers.microsoft.azure.hooks.data_factory.AzureDataFactoryHook.get_pipeline_run_status", - ) - @mock.patch("airflow.providers.microsoft.azure.hooks.data_factory.AzureDataFactoryHook.run_pipeline") - def test_azure_data_factory_run_pipeline_operator_async(self, mock_run_pipeline, mock_get_status, status): - """Assert that 
AzureDataFactoryRunPipelineOperatorAsync deferred""" - - class CreateRunResponse: - pass - - CreateRunResponse.run_id = AZ_PIPELINE_RUN_ID - mock_run_pipeline.return_value = CreateRunResponse - - with pytest.raises(TaskDeferred) as exc: - self.OPERATOR.execute(context=create_context(self.OPERATOR)) - - assert isinstance( - exc.value.trigger, AzureDataFactoryTrigger - ), "Trigger is not a AzureDataFactoryTrigger" - - def test_azure_data_factory_run_pipeline_operator_async_execute_complete_success(self): - """Assert that execute_complete log success message""" - - with mock.patch.object(self.OPERATOR.log, "info") as mock_log_info: - self.OPERATOR.execute_complete( - context=create_context(self.OPERATOR), - event={"status": "success", "message": "success", "run_id": AZ_PIPELINE_RUN_ID}, - ) - mock_log_info.assert_called_with("success") - - def test_azure_data_factory_run_pipeline_operator_async_execute_complete_fail(self): - """Assert that execute_complete raise exception on error""" - - with pytest.raises(AirflowException): - self.OPERATOR.execute_complete( - context=create_context(self.OPERATOR), - event={"status": "error", "message": "error", "run_id": AZ_PIPELINE_RUN_ID}, - ) + def test_init(self): + task = AzureDataFactoryRunPipelineOperatorAsync( + task_id="run_pipeline", + pipeline_name="pipeline", + parameters={"myParam": "value"}, + factory_name="factory_name", + resource_group_name="resource_group", + ) + + assert isinstance(task, AzureDataFactoryRunPipelineOperator) + assert task.deferrable is True