-
Notifications
You must be signed in to change notification settings - Fork 26
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Deprecate AzureDataFactoryRunPipelineOperatorAsync
Deprecate AzureDataFactoryRunPipelineOperatorAsync and proxy it to its Airflow OSS provider's counterpart related: #1412
- Loading branch information
1 parent
e7bf96c
commit 6ad0a03
Showing
4 changed files
with
48 additions
and
208 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
107 changes: 14 additions & 93 deletions
107
astronomer/providers/microsoft/azure/operators/data_factory.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,104 +1,25 @@ | ||
import time | ||
from typing import Dict | ||
import warnings | ||
|
||
from airflow.exceptions import AirflowException | ||
from airflow.providers.microsoft.azure.hooks.data_factory import ( | ||
AzureDataFactoryHook, | ||
AzureDataFactoryPipelineRunException, | ||
AzureDataFactoryPipelineRunStatus, | ||
) | ||
from airflow.providers.microsoft.azure.operators.data_factory import ( | ||
AzureDataFactoryRunPipelineOperator, | ||
) | ||
|
||
from astronomer.providers.microsoft.azure.triggers.data_factory import ( | ||
AzureDataFactoryTrigger, | ||
) | ||
from astronomer.providers.utils.typing_compat import Context | ||
|
||
|
||
class AzureDataFactoryRunPipelineOperatorAsync(AzureDataFactoryRunPipelineOperator): | ||
""" | ||
Executes a data factory pipeline asynchronously. | ||
:param azure_data_factory_conn_id: The connection identifier for connecting to Azure Data Factory. | ||
:param pipeline_name: The name of the pipeline to execute. | ||
:param wait_for_termination: Flag to wait on a pipeline run's termination. By default, this feature is | ||
enabled but could be disabled to perform an asynchronous wait for a long-running pipeline execution | ||
using the ``AzureDataFactoryPipelineRunSensor``. | ||
:param resource_group_name: The resource group name. If a value is not passed in to the operator, the | ||
``AzureDataFactoryHook`` will attempt to use the resource group name provided in the corresponding | ||
connection. | ||
:param factory_name: The data factory name. If a value is not passed in to the operator, the | ||
``AzureDataFactoryHook`` will attempt to use the factory name name provided in the corresponding | ||
connection. | ||
:param reference_pipeline_run_id: The pipeline run identifier. If this run ID is specified the parameters | ||
of the specified run will be used to create a new run. | ||
:param is_recovery: Recovery mode flag. If recovery mode is set to `True`, the specified referenced | ||
pipeline run and the new run will be grouped under the same ``groupId``. | ||
:param start_activity_name: In recovery mode, the rerun will start from this activity. If not specified, | ||
all activities will run. | ||
:param start_from_failure: In recovery mode, if set to true, the rerun will start from failed activities. | ||
The property will be used only if ``start_activity_name`` is not specified. | ||
:param parameters: Parameters of the pipeline run. These parameters are referenced in a pipeline via | ||
``@pipeline().parameters.parameterName`` and will be used only if the ``reference_pipeline_run_id`` is | ||
not specified. | ||
:param timeout: Time in seconds to wait for a pipeline to reach a terminal status for non-asynchronous | ||
waits. Used only if ``wait_for_termination`` is True. | ||
:param check_interval: Time in seconds to check on a pipeline run's status for non-asynchronous waits. | ||
Used only if ``wait_for_termination`` | ||
This class is deprecated. | ||
Use :class: `~airflow.providers.microsoft.azure.operators.data_factory.AzureDataFactoryRunPipelineOperator` instead | ||
and set `deferrable` param to `True` instead. | ||
""" | ||
|
||
def execute(self, context: Context) -> None: | ||
"""Submits a job which generates a run_id and gets deferred""" | ||
hook = AzureDataFactoryHook(azure_data_factory_conn_id=self.azure_data_factory_conn_id) | ||
response = hook.run_pipeline( | ||
pipeline_name=self.pipeline_name, | ||
resource_group_name=self.resource_group_name, | ||
factory_name=self.factory_name, | ||
reference_pipeline_run_id=self.reference_pipeline_run_id, | ||
is_recovery=self.is_recovery, | ||
start_activity_name=self.start_activity_name, | ||
start_from_failure=self.start_from_failure, | ||
parameters=self.parameters, | ||
def __init__(self, *args, **kwargs): # type: ignore[no-untyped-def] | ||
warnings.warn( | ||
( | ||
"This class is deprecated. " | ||
"Use `airflow.providers.microsoft.azure.operators.data_factory.AzureDataFactoryRunPipelineOperator` " | ||
"and set `deferrable` param to `True` instead." | ||
), | ||
DeprecationWarning, | ||
stacklevel=2, | ||
) | ||
run_id = vars(response)["run_id"] | ||
context["ti"].xcom_push(key="run_id", value=run_id) | ||
end_time = time.time() + self.timeout | ||
|
||
pipeline_run_status = hook.get_pipeline_run_status( | ||
run_id=run_id, | ||
resource_group_name=self.resource_group_name, | ||
factory_name=self.factory_name, | ||
) | ||
if pipeline_run_status not in AzureDataFactoryPipelineRunStatus.TERMINAL_STATUSES: | ||
self.defer( | ||
timeout=self.execution_timeout, | ||
trigger=AzureDataFactoryTrigger( | ||
azure_data_factory_conn_id=self.azure_data_factory_conn_id, | ||
run_id=run_id, | ||
wait_for_termination=self.wait_for_termination, | ||
resource_group_name=self.resource_group_name, | ||
factory_name=self.factory_name, | ||
check_interval=self.check_interval, | ||
end_time=end_time, | ||
), | ||
method_name="execute_complete", | ||
) | ||
elif pipeline_run_status == AzureDataFactoryPipelineRunStatus.SUCCEEDED: | ||
self.log.info("Pipeline run %s has completed successfully.", run_id) | ||
elif pipeline_run_status in AzureDataFactoryPipelineRunStatus.FAILURE_STATES: | ||
raise AzureDataFactoryPipelineRunException( | ||
f"Pipeline run {run_id} has failed or has been cancelled." | ||
) | ||
|
||
def execute_complete(self, context: Context, event: Dict[str, str]) -> None: | ||
""" | ||
Callback for when the trigger fires - returns immediately. | ||
Relies on trigger to throw an exception, otherwise it assumes execution was | ||
successful. | ||
""" | ||
if event: | ||
if event["status"] == "error": | ||
raise AirflowException(event["message"]) | ||
self.log.info(event["message"]) | ||
super().__init__(*args, deferrable=True, **kwargs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,109 +1,19 @@ | ||
from unittest import mock | ||
|
||
import pytest | ||
from airflow.exceptions import AirflowException, TaskDeferred | ||
from airflow.providers.microsoft.azure.hooks.data_factory import ( | ||
AzureDataFactoryPipelineRunException, | ||
AzureDataFactoryPipelineRunStatus, | ||
) | ||
from airflow.providers.microsoft.azure.operators.data_factory import AzureDataFactoryRunPipelineOperator | ||
|
||
from astronomer.providers.microsoft.azure.operators.data_factory import ( | ||
AzureDataFactoryRunPipelineOperatorAsync, | ||
) | ||
from astronomer.providers.microsoft.azure.triggers.data_factory import ( | ||
AzureDataFactoryTrigger, | ||
) | ||
from tests.utils.airflow_util import create_context | ||
|
||
AZ_PIPELINE_RUN_ID = "7f8c6c72-c093-11ec-a83d-0242ac120007" | ||
|
||
MODULE = "astronomer.providers.microsoft.azure.operators.data_factory" | ||
|
||
|
||
class TestAzureDataFactoryRunPipelineOperatorAsync: | ||
OPERATOR = AzureDataFactoryRunPipelineOperatorAsync( | ||
task_id="run_pipeline", | ||
pipeline_name="pipeline", | ||
parameters={"myParam": "value"}, | ||
factory_name="factory_name", | ||
resource_group_name="resource_group", | ||
) | ||
|
||
@mock.patch(f"{MODULE}.AzureDataFactoryRunPipelineOperatorAsync.defer") | ||
@mock.patch( | ||
"airflow.providers.microsoft.azure.hooks.data_factory.AzureDataFactoryHook.get_pipeline_run_status", | ||
return_value=AzureDataFactoryPipelineRunStatus.SUCCEEDED, | ||
) | ||
@mock.patch("airflow.providers.microsoft.azure.hooks.data_factory.AzureDataFactoryHook.run_pipeline") | ||
def test_azure_data_factory_run_pipeline_operator_async_succeeded_before_deferred( | ||
self, mock_run_pipeline, mock_get_status, mock_defer | ||
): | ||
class CreateRunResponse: | ||
pass | ||
|
||
CreateRunResponse.run_id = AZ_PIPELINE_RUN_ID | ||
mock_run_pipeline.return_value = CreateRunResponse | ||
|
||
self.OPERATOR.execute(context=create_context(self.OPERATOR)) | ||
assert not mock_defer.called | ||
|
||
@pytest.mark.parametrize("status", AzureDataFactoryPipelineRunStatus.FAILURE_STATES) | ||
@mock.patch(f"{MODULE}.AzureDataFactoryRunPipelineOperatorAsync.defer") | ||
@mock.patch( | ||
"airflow.providers.microsoft.azure.hooks.data_factory.AzureDataFactoryHook.get_pipeline_run_status", | ||
) | ||
@mock.patch("airflow.providers.microsoft.azure.hooks.data_factory.AzureDataFactoryHook.run_pipeline") | ||
def test_azure_data_factory_run_pipeline_operator_async_error_before_deferred( | ||
self, mock_run_pipeline, mock_get_status, mock_defer, status | ||
): | ||
mock_get_status.return_value = status | ||
|
||
class CreateRunResponse: | ||
pass | ||
|
||
CreateRunResponse.run_id = AZ_PIPELINE_RUN_ID | ||
mock_run_pipeline.return_value = CreateRunResponse | ||
|
||
with pytest.raises(AzureDataFactoryPipelineRunException): | ||
self.OPERATOR.execute(context=create_context(self.OPERATOR)) | ||
assert not mock_defer.called | ||
|
||
@pytest.mark.parametrize("status", AzureDataFactoryPipelineRunStatus.INTERMEDIATE_STATES) | ||
@mock.patch( | ||
"airflow.providers.microsoft.azure.hooks.data_factory.AzureDataFactoryHook.get_pipeline_run_status", | ||
) | ||
@mock.patch("airflow.providers.microsoft.azure.hooks.data_factory.AzureDataFactoryHook.run_pipeline") | ||
def test_azure_data_factory_run_pipeline_operator_async(self, mock_run_pipeline, mock_get_status, status): | ||
"""Assert that AzureDataFactoryRunPipelineOperatorAsync deferred""" | ||
|
||
class CreateRunResponse: | ||
pass | ||
|
||
CreateRunResponse.run_id = AZ_PIPELINE_RUN_ID | ||
mock_run_pipeline.return_value = CreateRunResponse | ||
|
||
with pytest.raises(TaskDeferred) as exc: | ||
self.OPERATOR.execute(context=create_context(self.OPERATOR)) | ||
|
||
assert isinstance( | ||
exc.value.trigger, AzureDataFactoryTrigger | ||
), "Trigger is not a AzureDataFactoryTrigger" | ||
|
||
def test_azure_data_factory_run_pipeline_operator_async_execute_complete_success(self): | ||
"""Assert that execute_complete log success message""" | ||
|
||
with mock.patch.object(self.OPERATOR.log, "info") as mock_log_info: | ||
self.OPERATOR.execute_complete( | ||
context=create_context(self.OPERATOR), | ||
event={"status": "success", "message": "success", "run_id": AZ_PIPELINE_RUN_ID}, | ||
) | ||
mock_log_info.assert_called_with("success") | ||
|
||
def test_azure_data_factory_run_pipeline_operator_async_execute_complete_fail(self): | ||
"""Assert that execute_complete raise exception on error""" | ||
|
||
with pytest.raises(AirflowException): | ||
self.OPERATOR.execute_complete( | ||
context=create_context(self.OPERATOR), | ||
event={"status": "error", "message": "error", "run_id": AZ_PIPELINE_RUN_ID}, | ||
) | ||
def test_init(self): | ||
task = AzureDataFactoryRunPipelineOperatorAsync( | ||
task_id="run_pipeline", | ||
pipeline_name="pipeline", | ||
parameters={"myParam": "value"}, | ||
factory_name="factory_name", | ||
resource_group_name="resource_group", | ||
) | ||
|
||
assert isinstance(task, AzureDataFactoryRunPipelineOperator) | ||
assert task.deferrable is True |