diff --git a/airflow/providers/microsoft/azure/CHANGELOG.rst b/airflow/providers/microsoft/azure/CHANGELOG.rst index f2bdd9f6de0cc..b355c60f8133d 100644 --- a/airflow/providers/microsoft/azure/CHANGELOG.rst +++ b/airflow/providers/microsoft/azure/CHANGELOG.rst @@ -27,6 +27,24 @@ Changelog --------- +8.0.0 +..... + +Breaking changes +~~~~~~~~~~~~~~~~ + +.. warning:: + In AzureDataFactoryHook methods and in AzureDataFactoryRunPipelineOperator, the arguments resource_group_name and + factory_name are now required instead of optional keyword arguments + +* resource_group_name and factory_name are now required arguments in the AzureDataFactoryHook methods get_factory, update_factory, + create_factory, delete_factory, get_linked_service, delete_linked_service, get_dataset, delete_dataset, get_dataflow, + update_dataflow, create_dataflow, delete_dataflow, get_pipeline, delete_pipeline, run_pipeline, get_pipeline_run, + get_trigger, get_pipeline_run_status, cancel_pipeline_run, create_trigger, delete_trigger, start_trigger, + stop_trigger, and in the AzureDataFactoryAsyncHook methods get_adf_pipeline_run_status and cancel_pipeline_run +* resource_group_name and factory_name are now required arguments in AzureDataFactoryRunPipelineOperator +* Remove class ``PipelineRunInfo`` from ``airflow.providers.microsoft.azure.hooks.data_factory`` + 7.0.0 ..... diff --git a/airflow/providers/microsoft/azure/hooks/data_factory.py b/airflow/providers/microsoft/azure/hooks/data_factory.py index b4516ccfd7f6b..656aa54a6a236 100644 --- a/airflow/providers/microsoft/azure/hooks/data_factory.py +++ b/airflow/providers/microsoft/azure/hooks/data_factory.py @@ -26,6 +26,7 @@ TriggerResource datafactory DataFlow + DataFlowResource mgmt """ from __future__ import annotations @@ -34,10 +35,9 @@ import time import warnings from functools import wraps -from typing import TYPE_CHECKING, Any, Callable, TypeVar, Union, cast +from typing import IO, TYPE_CHECKING, Any, Callable, TypeVar, Union, cast from asgiref.sync import sync_to_async -from azure.core.exceptions import ServiceRequestError from azure.identity import ClientSecretCredential, DefaultAzureCredential from azure.identity.aio import ( ClientSecretCredential as AsyncClientSecretCredential, @@ -48,13 +48,12 @@ from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.hooks.base import BaseHook -from airflow.typing_compat import TypedDict if TYPE_CHECKING: from azure.core.polling import LROPoller from azure.mgmt.datafactory.models import ( CreateRunResponse, - DataFlow, + DataFlowResource, DatasetResource, Factory, LinkedServiceResource, @@ -88,15 +87,9 @@ def bind_argument(arg, default_key): self = args[0] conn = self.get_connection(self.conn_id) extras = conn.extra_dejson - default_value = extras.get(default_key) - if not default_value and extras.get(f"extra__azure_data_factory__{default_key}"): - warnings.warn( - f"`extra__azure_data_factory__{default_key}` is deprecated in azure connection extra," - f" please use `{default_key}` instead", - AirflowProviderDeprecationWarning, - stacklevel=2, - ) - default_value = extras.get(f"extra__azure_data_factory__{default_key}") + default_value = extras.get(default_key) or extras.get( + f"extra__azure_data_factory__{default_key}" + ) if not default_value: raise AirflowException("Could not determine the targeted data factory.") @@ -110,14 +103,6 @@ def bind_argument(arg, default_key): return wrapper -class PipelineRunInfo(TypedDict): - """Type class for the pipeline run info dictionary.""" - - run_id: str - factory_name: str | None - resource_group_name: str | None - - class 
AzureDataFactoryPipelineRunStatus: """Azure Data Factory pipeline operation statuses.""" @@ -127,6 +112,7 @@ class AzureDataFactoryPipelineRunStatus: FAILED = "Failed" CANCELING = "Canceling" CANCELLED = "Cancelled" + TERMINAL_STATUSES = {CANCELLED, FAILED, SUCCEEDED} INTERMEDIATE_STATES = {QUEUED, IN_PROGRESS, CANCELING} FAILURE_STATES = {FAILED, CANCELLED} @@ -148,12 +134,6 @@ def get_field(extras: dict, field_name: str, strict: bool = False): return extras[field_name] or None prefixed_name = f"{backcompat_prefix}{field_name}" if prefixed_name in extras: - warnings.warn( - f"`{prefixed_name}` is deprecated in azure connection extra," - f" please use `{field_name}` instead", - AirflowProviderDeprecationWarning, - stacklevel=2, - ) return extras[prefixed_name] or None if strict: raise KeyError(f"Field {field_name} not found in extras") @@ -199,7 +179,7 @@ def get_ui_field_behaviour() -> dict[str, Any]: } def __init__(self, azure_data_factory_conn_id: str = default_conn_name): - self._conn: DataFactoryManagementClient = None + self._conn: DataFactoryManagementClient | None = None self.conn_id = azure_data_factory_conn_id super().__init__() @@ -235,9 +215,7 @@ def refresh_conn(self) -> DataFactoryManagementClient: return self.get_conn() @provide_targeted_factory - def get_factory( - self, resource_group_name: str | None = None, factory_name: str | None = None, **config: Any - ) -> Factory: + def get_factory(self, resource_group_name: str, factory_name: str, **config: Any) -> Factory | None: """ Get the factory. @@ -267,8 +245,9 @@ def _create_client(credential: Credentials, subscription_id: str): def update_factory( self, factory: Factory, - resource_group_name: str | None = None, - factory_name: str | None = None, + resource_group_name: str, + factory_name: str, + if_match: str | None = None, **config: Any, ) -> Factory: """ @@ -277,6 +256,8 @@ def update_factory( :param factory: The factory resource definition. :param resource_group_name: The resource group name. :param factory_name: The factory name. + :param if_match: ETag of the factory entity. Should only be specified for update, for which it + should match existing entity or can be * for unconditional update. Default value is None. :param config: Extra parameters for the ADF client. :raise AirflowException: If the factory does not exist. :return: The factory. @@ -285,15 +266,15 @@ def update_factory( raise AirflowException(f"Factory {factory!r} does not exist.") return self.get_conn().factories.create_or_update( - resource_group_name, factory_name, factory, **config + resource_group_name, factory_name, factory, if_match, **config ) @provide_targeted_factory def create_factory( self, factory: Factory, - resource_group_name: str | None = None, - factory_name: str | None = None, + resource_group_name: str, + factory_name: str, **config: Any, ) -> Factory: """ @@ -314,9 +295,7 @@ def create_factory( ) @provide_targeted_factory - def delete_factory( - self, resource_group_name: str | None = None, factory_name: str | None = None, **config: Any - ) -> None: + def delete_factory(self, resource_group_name: str, factory_name: str, **config: Any) -> None: """ Delete the factory. @@ -330,21 +309,25 @@ def delete_factory( def get_linked_service( self, linked_service_name: str, - resource_group_name: str | None = None, - factory_name: str | None = None, + resource_group_name: str, + factory_name: str, + if_none_match: str | None = None, **config: Any, - ) -> LinkedServiceResource: + ) -> LinkedServiceResource | None: """ Get the linked service. 
:param linked_service_name: The linked service name. :param resource_group_name: The resource group name. :param factory_name: The factory name. + :param if_none_match: ETag of the linked service entity. Should only be specified for get. If + the ETag matches the existing entity tag, or if * was provided, then no content will be + returned. Default value is None. :param config: Extra parameters for the ADF client. :return: The linked service. """ return self.get_conn().linked_services.get( - resource_group_name, factory_name, linked_service_name, **config + resource_group_name, factory_name, linked_service_name, if_none_match, **config ) def _linked_service_exists(self, resource_group_name, factory_name, linked_service_name) -> bool: @@ -363,8 +346,8 @@ def update_linked_service( self, linked_service_name: str, linked_service: LinkedServiceResource, - resource_group_name: str | None = None, - factory_name: str | None = None, + resource_group_name: str, + factory_name: str, **config: Any, ) -> LinkedServiceResource: """ @@ -390,8 +373,8 @@ def create_linked_service( self, linked_service_name: str, linked_service: LinkedServiceResource, - resource_group_name: str | None = None, - factory_name: str | None = None, + resource_group_name: str, + factory_name: str, **config: Any, ) -> LinkedServiceResource: """ @@ -416,8 +399,8 @@ def create_linked_service( def delete_linked_service( self, linked_service_name: str, - resource_group_name: str | None = None, - factory_name: str | None = None, + resource_group_name: str, + factory_name: str, **config: Any, ) -> None: """ @@ -436,10 +419,10 @@ def delete_linked_service( def get_dataset( self, dataset_name: str, - resource_group_name: str | None = None, - factory_name: str | None = None, + resource_group_name: str, + factory_name: str, **config: Any, - ) -> DatasetResource: + ) -> DatasetResource | None: """ Get the dataset. @@ -465,8 +448,8 @@ def update_dataset( self, dataset_name: str, dataset: DatasetResource, - resource_group_name: str | None = None, - factory_name: str | None = None, + resource_group_name: str, + factory_name: str, **config: Any, ) -> DatasetResource: """ @@ -492,8 +475,8 @@ def create_dataset( self, dataset_name: str, dataset: DatasetResource, - resource_group_name: str | None = None, - factory_name: str | None = None, + resource_group_name: str, + factory_name: str, **config: Any, ) -> DatasetResource: """ @@ -518,8 +501,8 @@ def create_dataset( def delete_dataset( self, dataset_name: str, - resource_group_name: str | None = None, - factory_name: str | None = None, + resource_group_name: str, + factory_name: str, **config: Any, ) -> None: """ @@ -536,26 +519,32 @@ def delete_dataset( def get_dataflow( self, dataflow_name: str, - resource_group_name: str | None = None, - factory_name: str | None = None, + resource_group_name: str, + factory_name: str, + if_none_match: str | None = None, **config: Any, - ) -> DataFlow: + ) -> DataFlowResource: """ Get the dataflow. :param dataflow_name: The dataflow name. :param resource_group_name: The resource group name. :param factory_name: The factory name. + :param if_none_match: ETag of the data flow entity. Should only be specified for get. If the + ETag matches the existing entity tag, or if * was provided, then no content will be returned. + Default value is None. :param config: Extra parameters for the ADF client. - :return: The dataflow. + :return: The DataFlowResource. 
""" - return self.get_conn().data_flows.get(resource_group_name, factory_name, dataflow_name, **config) + return self.get_conn().data_flows.get( + resource_group_name, factory_name, dataflow_name, if_none_match, **config + ) def _dataflow_exists( self, dataflow_name: str, - resource_group_name: str | None = None, - factory_name: str | None = None, + resource_group_name: str, + factory_name: str, ) -> bool: """Return whether the dataflow already exists.""" dataflows = { @@ -569,11 +558,12 @@ def _dataflow_exists( def update_dataflow( self, dataflow_name: str, - dataflow: DataFlow, - resource_group_name: str | None = None, - factory_name: str | None = None, + dataflow: DataFlowResource | IO, + resource_group_name: str, + factory_name: str, + if_match: str | None = None, **config: Any, - ) -> DataFlow: + ) -> DataFlowResource: """ Update the dataflow. @@ -581,9 +571,11 @@ def update_dataflow( :param dataflow: The dataflow resource definition. :param resource_group_name: The resource group name. :param factory_name: The factory name. + :param if_match: ETag of the data flow entity. Should only be specified for update, for which + it should match existing entity or can be * for unconditional update. Default value is None. :param config: Extra parameters for the ADF client. :raise AirflowException: If the dataset does not exist. - :return: The dataflow. + :return: DataFlowResource. """ if not self._dataflow_exists( dataflow_name, @@ -593,18 +585,19 @@ def update_dataflow( raise AirflowException(f"Dataflow {dataflow_name!r} does not exist.") return self.get_conn().data_flows.create_or_update( - resource_group_name, factory_name, dataflow_name, dataflow, **config + resource_group_name, factory_name, dataflow_name, dataflow, if_match, **config ) @provide_targeted_factory def create_dataflow( self, dataflow_name: str, - dataflow: DataFlow, - resource_group_name: str | None = None, - factory_name: str | None = None, + dataflow: DataFlowResource, + resource_group_name: str, + factory_name: str, + if_match: str | None = None, **config: Any, - ) -> DataFlow: + ) -> DataFlowResource: """ Create the dataflow. @@ -612,6 +605,8 @@ def create_dataflow( :param dataflow: The dataflow resource definition. :param resource_group_name: The resource group name. :param factory_name: The factory name. + :param if_match: ETag of the factory entity. Should only be specified for update, for which it + should match existing entity or can be * for unconditional update. Default value is None. :param config: Extra parameters for the ADF client. :raise AirflowException: If the dataset already exists. :return: The dataset. @@ -620,15 +615,15 @@ def create_dataflow( raise AirflowException(f"Dataflow {dataflow_name!r} already exists.") return self.get_conn().data_flows.create_or_update( - resource_group_name, factory_name, dataflow_name, dataflow, **config + resource_group_name, factory_name, dataflow_name, dataflow, if_match, **config ) @provide_targeted_factory def delete_dataflow( self, dataflow_name: str, - resource_group_name: str | None = None, - factory_name: str | None = None, + resource_group_name: str, + factory_name: str, **config: Any, ) -> None: """ @@ -645,10 +640,10 @@ def delete_dataflow( def get_pipeline( self, pipeline_name: str, - resource_group_name: str | None = None, - factory_name: str | None = None, + resource_group_name: str, + factory_name: str, **config: Any, - ) -> PipelineResource: + ) -> PipelineResource | None: """ Get the pipeline. 
@@ -674,8 +669,8 @@ def update_pipeline( self, pipeline_name: str, pipeline: PipelineResource, - resource_group_name: str | None = None, - factory_name: str | None = None, + resource_group_name: str, + factory_name: str, **config: Any, ) -> PipelineResource: """ @@ -701,8 +696,8 @@ def create_pipeline( self, pipeline_name: str, pipeline: PipelineResource, - resource_group_name: str | None = None, - factory_name: str | None = None, + resource_group_name: str, + factory_name: str, **config: Any, ) -> PipelineResource: """ @@ -727,8 +722,8 @@ def create_pipeline( def delete_pipeline( self, pipeline_name: str, - resource_group_name: str | None = None, - factory_name: str | None = None, + resource_group_name: str, + factory_name: str, **config: Any, ) -> None: """ @@ -745,8 +740,8 @@ def delete_pipeline( def run_pipeline( self, pipeline_name: str, - resource_group_name: str | None = None, - factory_name: str | None = None, + resource_group_name: str, + factory_name: str, **config: Any, ) -> CreateRunResponse: """ @@ -766,8 +761,8 @@ def run_pipeline( def get_pipeline_run( self, run_id: str, - resource_group_name: str | None = None, - factory_name: str | None = None, + resource_group_name: str, + factory_name: str, **config: Any, ) -> PipelineRun: """ @@ -784,8 +779,8 @@ def get_pipeline_run( def get_pipeline_run_status( self, run_id: str, - resource_group_name: str | None = None, - factory_name: str | None = None, + resource_group_name: str, + factory_name: str, ) -> str: """ Get a pipeline run's current status. @@ -796,11 +791,7 @@ def get_pipeline_run_status( :return: The status of the pipeline run. """ self.log.info("Getting the status of run ID %s.", run_id) - pipeline_run_status = self.get_pipeline_run( - run_id=run_id, - factory_name=factory_name, - resource_group_name=resource_group_name, - ).status + pipeline_run_status = self.get_pipeline_run(run_id, resource_group_name, factory_name).status self.log.info("Current status of pipeline run %s: %s", run_id, pipeline_run_status) return pipeline_run_status @@ -809,8 +800,8 @@ def wait_for_pipeline_run_status( self, run_id: str, expected_statuses: str | set[str], - resource_group_name: str | None = None, - factory_name: str | None = None, + resource_group_name: str, + factory_name: str, check_interval: int = 60, timeout: int = 60 * 60 * 24 * 7, ) -> bool: @@ -826,13 +817,7 @@ def wait_for_pipeline_run_status( status. :return: Boolean indicating if the pipeline run has reached the ``expected_status``. """ - pipeline_run_info = PipelineRunInfo( - run_id=run_id, - factory_name=factory_name, - resource_group_name=resource_group_name, - ) - pipeline_run_status = self.get_pipeline_run_status(**pipeline_run_info) - executed_after_token_refresh = True + pipeline_run_status = self.get_pipeline_run_status(run_id, resource_group_name, factory_name) start_time = time.monotonic() @@ -849,14 +834,7 @@ def wait_for_pipeline_run_status( # Wait to check the status of the pipeline run based on the ``check_interval`` configured. 
time.sleep(check_interval) - try: - pipeline_run_status = self.get_pipeline_run_status(**pipeline_run_info) - executed_after_token_refresh = True - except ServiceRequestError: - if executed_after_token_refresh: - self.refresh_conn() - else: - raise + pipeline_run_status = self.get_pipeline_run_status(run_id, resource_group_name, factory_name) return pipeline_run_status in expected_statuses @@ -864,8 +842,8 @@ def wait_for_pipeline_run_status( def cancel_pipeline_run( self, run_id: str, - resource_group_name: str | None = None, - factory_name: str | None = None, + resource_group_name: str, + factory_name: str, **config: Any, ) -> None: """ @@ -882,10 +860,10 @@ def cancel_pipeline_run( def get_trigger( self, trigger_name: str, - resource_group_name: str | None = None, - factory_name: str | None = None, + resource_group_name: str, + factory_name: str, **config: Any, - ) -> TriggerResource: + ) -> TriggerResource | None: """ Get the trigger. @@ -911,8 +889,9 @@ def update_trigger( self, trigger_name: str, trigger: TriggerResource, - resource_group_name: str | None = None, - factory_name: str | None = None, + resource_group_name: str, + factory_name: str, + if_match: str | None = None, **config: Any, ) -> TriggerResource: """ @@ -922,6 +901,8 @@ def update_trigger( :param trigger: The trigger resource definition. :param resource_group_name: The resource group name. :param factory_name: The factory name. + :param if_match: ETag of the trigger entity. Should only be specified for update, for which it + should match existing entity or can be * for unconditional update. Default value is None. :param config: Extra parameters for the ADF client. :raise AirflowException: If the trigger does not exist. :return: The trigger. @@ -930,7 +911,7 @@ def update_trigger( raise AirflowException(f"Trigger {trigger_name!r} does not exist.") return self.get_conn().triggers.create_or_update( - resource_group_name, factory_name, trigger_name, trigger, **config + resource_group_name, factory_name, trigger_name, trigger, if_match, **config ) @provide_targeted_factory @@ -938,8 +919,8 @@ def create_trigger( self, trigger_name: str, trigger: TriggerResource, - resource_group_name: str | None = None, - factory_name: str | None = None, + resource_group_name: str, + factory_name: str, **config: Any, ) -> TriggerResource: """ @@ -964,8 +945,8 @@ def create_trigger( def delete_trigger( self, trigger_name: str, - resource_group_name: str | None = None, - factory_name: str | None = None, + resource_group_name: str, + factory_name: str, **config: Any, ) -> None: """ @@ -982,8 +963,8 @@ def delete_trigger( def start_trigger( self, trigger_name: str, - resource_group_name: str | None = None, - factory_name: str | None = None, + resource_group_name: str, + factory_name: str, **config: Any, ) -> LROPoller: """ @@ -1001,8 +982,8 @@ def start_trigger( def stop_trigger( self, trigger_name: str, - resource_group_name: str | None = None, - factory_name: str | None = None, + resource_group_name: str, + factory_name: str, **config: Any, ) -> LROPoller: """ @@ -1021,8 +1002,8 @@ def rerun_trigger( self, trigger_name: str, run_id: str, - resource_group_name: str | None = None, - factory_name: str | None = None, + resource_group_name: str, + factory_name: str, **config: Any, ) -> None: """ @@ -1043,8 +1024,8 @@ def cancel_trigger( self, trigger_name: str, run_id: str, - resource_group_name: str | None = None, - factory_name: str | None = None, + resource_group_name: str, + factory_name: str, **config: Any, ) -> None: """ @@ -1068,7 +1049,7 
@@ def test_connection(self) -> tuple[bool, str]: # DataFactoryManagementClient with incorrect values but then will fail properly once items are # retrieved using the client. We need to _actually_ try to retrieve an object to properly test the # connection. - next(self.get_conn().factories.list()) + self.get_conn().factories.list() return success except StopIteration: # If the iterator returned is empty it should still be considered a successful connection since @@ -1132,7 +1113,7 @@ class AzureDataFactoryAsyncHook(AzureDataFactoryHook): default_conn_name: str = "azure_data_factory_default" def __init__(self, azure_data_factory_conn_id: str = default_conn_name): - self._async_conn: AsyncDataFactoryManagementClient = None + self._async_conn: AsyncDataFactoryManagementClient | None = None self.conn_id = azure_data_factory_conn_id super().__init__(azure_data_factory_conn_id=azure_data_factory_conn_id) @@ -1168,7 +1149,7 @@ async def get_async_conn(self) -> AsyncDataFactoryManagementClient: return self._async_conn - async def refresh_conn(self) -> AsyncDataFactoryManagementClient: + async def refresh_conn(self) -> AsyncDataFactoryManagementClient: # type: ignore[override] self._conn = None return await self.get_async_conn() @@ -1176,8 +1157,8 @@ async def refresh_conn(self) -> AsyncDataFactoryManagementClient: async def get_pipeline_run( self, run_id: str, - resource_group_name: str | None = None, - factory_name: str | None = None, + resource_group_name: str, + factory_name: str, **config: Any, ) -> PipelineRun: """ @@ -1193,7 +1174,7 @@ async def get_pipeline_run( return pipeline_run async def get_adf_pipeline_run_status( - self, run_id: str, resource_group_name: str | None = None, factory_name: str | None = None + self, run_id: str, resource_group_name: str, factory_name: str ) -> str: """ Connect to Azure Data Factory asynchronously and get the pipeline status by run_id. @@ -1202,20 +1183,16 @@ async def get_adf_pipeline_run_status( :param resource_group_name: The resource group name. :param factory_name: The factory name. 
""" - pipeline_run = await self.get_pipeline_run( - run_id=run_id, - factory_name=factory_name, - resource_group_name=resource_group_name, - ) - status: str = pipeline_run.status + pipeline_run = await self.get_pipeline_run(run_id, resource_group_name, factory_name) + status: str = cast(str, pipeline_run.status) return status @provide_targeted_factory_async async def cancel_pipeline_run( self, run_id: str, - resource_group_name: str | None = None, - factory_name: str | None = None, + resource_group_name: str, + factory_name: str, **config: Any, ) -> None: """ diff --git a/airflow/providers/microsoft/azure/operators/data_factory.py b/airflow/providers/microsoft/azure/operators/data_factory.py index 12962e5610228..2aac723f86c63 100644 --- a/airflow/providers/microsoft/azure/operators/data_factory.py +++ b/airflow/providers/microsoft/azure/operators/data_factory.py @@ -29,7 +29,6 @@ AzureDataFactoryHook, AzureDataFactoryPipelineRunException, AzureDataFactoryPipelineRunStatus, - PipelineRunInfo, get_field, ) from airflow.providers.microsoft.azure.triggers.data_factory import AzureDataFactoryTrigger @@ -132,9 +131,9 @@ def __init__( *, pipeline_name: str, azure_data_factory_conn_id: str = AzureDataFactoryHook.default_conn_name, + resource_group_name: str, + factory_name: str, wait_for_termination: bool = True, - resource_group_name: str | None = None, - factory_name: str | None = None, reference_pipeline_run_id: str | None = None, is_recovery: bool | None = None, start_activity_name: str | None = None, @@ -168,9 +167,9 @@ def hook(self) -> AzureDataFactoryHook: def execute(self, context: Context) -> None: self.log.info("Executing the %s pipeline.", self.pipeline_name) response = self.hook.run_pipeline( - pipeline_name=self.pipeline_name, - resource_group_name=self.resource_group_name, - factory_name=self.factory_name, + self.pipeline_name, + self.resource_group_name, + self.factory_name, reference_pipeline_run_id=self.reference_pipeline_run_id, is_recovery=self.is_recovery, start_activity_name=self.start_activity_name, @@ -188,12 +187,12 @@ def execute(self, context: Context) -> None: self.log.info("Waiting for pipeline run %s to terminate.", self.run_id) if self.hook.wait_for_pipeline_run_status( - run_id=self.run_id, - expected_statuses=AzureDataFactoryPipelineRunStatus.SUCCEEDED, + self.run_id, + AzureDataFactoryPipelineRunStatus.SUCCEEDED, + self.resource_group_name, + self.factory_name, check_interval=self.check_interval, timeout=self.timeout, - resource_group_name=self.resource_group_name, - factory_name=self.factory_name, ): self.log.info("Pipeline run %s has completed successfully.", self.run_id) else: @@ -202,12 +201,9 @@ def execute(self, context: Context) -> None: ) else: end_time = time.time() + self.timeout - pipeline_run_info = PipelineRunInfo( - run_id=self.run_id, - factory_name=self.factory_name, - resource_group_name=self.resource_group_name, + pipeline_run_status = self.hook.get_pipeline_run_status( + self.run_id, self.resource_group_name, self.factory_name ) - pipeline_run_status = self.hook.get_pipeline_run_status(**pipeline_run_info) if pipeline_run_status not in AzureDataFactoryPipelineRunStatus.TERMINAL_STATUSES: self.defer( timeout=self.execution_timeout, diff --git a/airflow/providers/microsoft/azure/provider.yaml b/airflow/providers/microsoft/azure/provider.yaml index e2822f382257e..d01ee268ef3ff 100644 --- a/airflow/providers/microsoft/azure/provider.yaml +++ b/airflow/providers/microsoft/azure/provider.yaml @@ -21,6 +21,7 @@ description: | `Microsoft Azure `__ 
suspended: false versions: + - 8.0.0 - 7.0.0 - 6.3.0 - 6.2.4 @@ -81,11 +82,11 @@ dependencies: - adal>=1.2.7 - azure-storage-file-datalake>=12.9.1 - azure-kusto-data>=4.1.0 + - azure-mgmt-datafactory>=2.0.0 - azure-mgmt-containerregistry>=8.0.0 # TODO: upgrade to newer versions of all the below libraries. # See issue https://github.com/apache/airflow/issues/30199 - azure-mgmt-containerinstance>=7.0.0,<9.0.0 - - azure-mgmt-datafactory>=1.0.0,<2.0 integrations: - integration-name: Microsoft Azure Batch diff --git a/airflow/providers/microsoft/azure/sensors/data_factory.py b/airflow/providers/microsoft/azure/sensors/data_factory.py index 5cede76ad9449..f5ca765e8545d 100644 --- a/airflow/providers/microsoft/azure/sensors/data_factory.py +++ b/airflow/providers/microsoft/azure/sensors/data_factory.py @@ -59,8 +59,8 @@ def __init__( *, run_id: str, azure_data_factory_conn_id: str = AzureDataFactoryHook.default_conn_name, - resource_group_name: str | None = None, - factory_name: str | None = None, + resource_group_name: str, + factory_name: str, deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs, ) -> None: diff --git a/airflow/providers/microsoft/azure/triggers/data_factory.py b/airflow/providers/microsoft/azure/triggers/data_factory.py index e087e3556d666..b550c2b1d197a 100644 --- a/airflow/providers/microsoft/azure/triggers/data_factory.py +++ b/airflow/providers/microsoft/azure/triggers/data_factory.py @@ -44,8 +44,8 @@ def __init__( run_id: str, azure_data_factory_conn_id: str, poke_interval: float, - resource_group_name: str | None = None, - factory_name: str | None = None, + resource_group_name: str, + factory_name: str, ): super().__init__() self.run_id = run_id @@ -128,8 +128,8 @@ def __init__( run_id: str, azure_data_factory_conn_id: str, end_time: float, - resource_group_name: str | None = None, - factory_name: str | None = None, + resource_group_name: str, + factory_name: str, wait_for_termination: bool = True, check_interval: int = 60, ): diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 549492a1133b5..14d7d52dd0737 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -11,6 +11,7 @@ actionCard Acyclic acyclic AddressesType +adf adhoc adls adobjects diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index 028caf09ae22b..212f3089891c3 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -558,7 +558,7 @@ "azure-mgmt-containerinstance>=7.0.0,<9.0.0", "azure-mgmt-containerregistry>=8.0.0", "azure-mgmt-cosmosdb", - "azure-mgmt-datafactory>=1.0.0,<2.0", + "azure-mgmt-datafactory>=2.0.0", "azure-mgmt-datalake-store>=0.5.0", "azure-mgmt-resource>=2.2.0", "azure-mgmt-storage>=16.0.0", diff --git a/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py b/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py index 508bc2ee78202..34d0bcb48180d 100644 --- a/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py +++ b/tests/providers/microsoft/azure/hooks/test_azure_data_factory.py @@ -183,301 +183,186 @@ def test_get_connection_by_credential_client_secret(connection_id: str, credenti assert mock_create_client.call_args.args[1] == "subscriptionId" -@parametrize( - explicit_factory=((RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY)), - implicit_factory=((), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY)), -) -def test_get_factory(hook: AzureDataFactoryHook, user_args, sdk_args): - 
hook.get_factory(*user_args) +def test_get_factory(hook: AzureDataFactoryHook): + hook.get_factory(RESOURCE_GROUP, FACTORY) - hook._conn.factories.get.assert_called_with(*sdk_args) + hook._conn.factories.get.assert_called_with(RESOURCE_GROUP, FACTORY) -@parametrize( - explicit_factory=((MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, MODEL)), - implicit_factory=((MODEL,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, MODEL)), -) -def test_create_factory(hook: AzureDataFactoryHook, user_args, sdk_args): - hook.create_factory(*user_args) +def test_create_factory(hook: AzureDataFactoryHook): + hook.create_factory(MODEL, RESOURCE_GROUP, FACTORY) - hook._conn.factories.create_or_update.assert_called_with(*sdk_args) + hook._conn.factories.create_or_update.assert_called_with(RESOURCE_GROUP, FACTORY, MODEL) -@parametrize( - explicit_factory=((MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, MODEL)), - implicit_factory=((MODEL,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, MODEL)), -) -def test_update_factory(hook: AzureDataFactoryHook, user_args, sdk_args): +def test_update_factory(hook: AzureDataFactoryHook): with patch.object(hook, "_factory_exists") as mock_factory_exists: mock_factory_exists.return_value = True - hook.update_factory(*user_args) + hook.update_factory(MODEL, RESOURCE_GROUP, FACTORY) - hook._conn.factories.create_or_update.assert_called_with(*sdk_args) + hook._conn.factories.create_or_update.assert_called_with(RESOURCE_GROUP, FACTORY, MODEL, None) -@parametrize( - explicit_factory=((MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, MODEL)), - implicit_factory=((MODEL,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, MODEL)), -) -def test_update_factory_non_existent(hook: AzureDataFactoryHook, user_args, sdk_args): +def test_update_factory_non_existent(hook: AzureDataFactoryHook): with patch.object(hook, "_factory_exists") as mock_factory_exists: mock_factory_exists.return_value = False with pytest.raises(AirflowException, match=r"Factory .+ does not exist"): - hook.update_factory(*user_args) + hook.update_factory(MODEL, RESOURCE_GROUP, FACTORY) -@parametrize( - explicit_factory=((RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY)), - implicit_factory=((), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY)), -) -def test_delete_factory(hook: AzureDataFactoryHook, user_args, sdk_args): - hook.delete_factory(*user_args) +def test_delete_factory(hook: AzureDataFactoryHook): + hook.delete_factory(RESOURCE_GROUP, FACTORY) - hook._conn.factories.delete.assert_called_with(*sdk_args) + hook._conn.factories.delete.assert_called_with(RESOURCE_GROUP, FACTORY) -@parametrize( - explicit_factory=((NAME, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME)), - implicit_factory=((NAME,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME)), -) -def test_get_linked_service(hook: AzureDataFactoryHook, user_args, sdk_args): - hook.get_linked_service(*user_args) +def test_get_linked_service(hook: AzureDataFactoryHook): + hook.get_linked_service(NAME, RESOURCE_GROUP, FACTORY) - hook._conn.linked_services.get.assert_called_with(*sdk_args) + hook._conn.linked_services.get.assert_called_with(RESOURCE_GROUP, FACTORY, NAME, None) -@parametrize( - explicit_factory=((NAME, MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME, MODEL)), - implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME, MODEL)), -) -def test_create_linked_service(hook: AzureDataFactoryHook, user_args, sdk_args): - hook.create_linked_service(*user_args) +def test_create_linked_service(hook: 
AzureDataFactoryHook): + hook.create_linked_service(NAME, MODEL, RESOURCE_GROUP, FACTORY) - hook._conn.linked_services.create_or_update(*sdk_args) + hook._conn.linked_services.create_or_update(RESOURCE_GROUP, FACTORY, NAME, MODEL) -@parametrize( - explicit_factory=((NAME, MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME, MODEL)), - implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME, MODEL)), -) -def test_update_linked_service(hook: AzureDataFactoryHook, user_args, sdk_args): +def test_update_linked_service(hook: AzureDataFactoryHook): with patch.object(hook, "_linked_service_exists") as mock_linked_service_exists: mock_linked_service_exists.return_value = True - hook.update_linked_service(*user_args) + hook.update_linked_service(NAME, MODEL, RESOURCE_GROUP, FACTORY) - hook._conn.linked_services.create_or_update(*sdk_args) + hook._conn.linked_services.create_or_update(RESOURCE_GROUP, FACTORY, NAME, MODEL) -@parametrize( - explicit_factory=((NAME, MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME, MODEL)), - implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME, MODEL)), -) -def test_update_linked_service_non_existent(hook: AzureDataFactoryHook, user_args, sdk_args): +def test_update_linked_service_non_existent(hook: AzureDataFactoryHook): with patch.object(hook, "_linked_service_exists") as mock_linked_service_exists: mock_linked_service_exists.return_value = False with pytest.raises(AirflowException, match=r"Linked service .+ does not exist"): - hook.update_linked_service(*user_args) + hook.update_linked_service(NAME, MODEL, RESOURCE_GROUP, FACTORY) -@parametrize( - explicit_factory=((NAME, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME)), - implicit_factory=((NAME,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME)), -) -def test_delete_linked_service(hook: AzureDataFactoryHook, user_args, sdk_args): - hook.delete_linked_service(*user_args) +def test_delete_linked_service(hook: AzureDataFactoryHook): + hook.delete_linked_service(NAME, RESOURCE_GROUP, FACTORY) - hook._conn.linked_services.delete.assert_called_with(*sdk_args) + hook._conn.linked_services.delete.assert_called_with(RESOURCE_GROUP, FACTORY, NAME) -@parametrize( - explicit_factory=((NAME, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME)), - implicit_factory=((NAME,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME)), -) -def test_get_dataset(hook: AzureDataFactoryHook, user_args, sdk_args): - hook.get_dataset(*user_args) +def test_get_dataset(hook: AzureDataFactoryHook): + hook.get_dataset(NAME, RESOURCE_GROUP, FACTORY) - hook._conn.datasets.get.assert_called_with(*sdk_args) + hook._conn.datasets.get.assert_called_with(RESOURCE_GROUP, FACTORY, NAME) -@parametrize( - explicit_factory=((NAME, MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME, MODEL)), - implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME, MODEL)), -) -def test_create_dataset(hook: AzureDataFactoryHook, user_args, sdk_args): - hook.create_dataset(*user_args) +def test_create_dataset(hook: AzureDataFactoryHook): + hook.create_dataset(NAME, MODEL, RESOURCE_GROUP, FACTORY) - hook._conn.datasets.create_or_update.assert_called_with(*sdk_args) + hook._conn.datasets.create_or_update.assert_called_with(RESOURCE_GROUP, FACTORY, NAME, MODEL) -@parametrize( - explicit_factory=((NAME, MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME, MODEL)), - implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, 
NAME, MODEL)), -) -def test_update_dataset(hook: AzureDataFactoryHook, user_args, sdk_args): +def test_update_dataset(hook: AzureDataFactoryHook): with patch.object(hook, "_dataset_exists") as mock_dataset_exists: mock_dataset_exists.return_value = True - hook.update_dataset(*user_args) + hook.update_dataset(NAME, MODEL, RESOURCE_GROUP, FACTORY) - hook._conn.datasets.create_or_update.assert_called_with(*sdk_args) + hook._conn.datasets.create_or_update.assert_called_with(RESOURCE_GROUP, FACTORY, NAME, MODEL) -@parametrize( - explicit_factory=((NAME, MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME, MODEL)), - implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME, MODEL)), -) -def test_update_dataset_non_existent(hook: AzureDataFactoryHook, user_args, sdk_args): +def test_update_dataset_non_existent(hook: AzureDataFactoryHook): with patch.object(hook, "_dataset_exists") as mock_dataset_exists: mock_dataset_exists.return_value = False with pytest.raises(AirflowException, match=r"Dataset .+ does not exist"): - hook.update_dataset(*user_args) + hook.update_dataset(NAME, MODEL, RESOURCE_GROUP, FACTORY) -@parametrize( - explicit_factory=((NAME, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME)), - implicit_factory=((NAME,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME)), -) -def test_delete_dataset(hook: AzureDataFactoryHook, user_args, sdk_args): - hook.delete_dataset(*user_args) +def test_delete_dataset(hook: AzureDataFactoryHook): + hook.delete_dataset(NAME, RESOURCE_GROUP, FACTORY) - hook._conn.datasets.delete.assert_called_with(*sdk_args) + hook._conn.datasets.delete.assert_called_with(RESOURCE_GROUP, FACTORY, NAME) -@parametrize( - explicit_factory=((NAME, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME)), - implicit_factory=((NAME,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME)), -) -def test_get_dataflow(hook: AzureDataFactoryHook, user_args, sdk_args): - hook.get_dataflow(*user_args) +def test_get_dataflow(hook: AzureDataFactoryHook): + hook.get_dataflow(NAME, RESOURCE_GROUP, FACTORY) - hook._conn.data_flows.get.assert_called_with(*sdk_args) + hook._conn.data_flows.get.assert_called_with(RESOURCE_GROUP, FACTORY, NAME, None) -@parametrize( - explicit_factory=((NAME, MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME, MODEL)), - implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME, MODEL)), -) -def test_create_dataflow(hook: AzureDataFactoryHook, user_args, sdk_args): - hook.create_dataflow(*user_args) +def test_create_dataflow(hook: AzureDataFactoryHook): + hook.create_dataflow(NAME, MODEL, RESOURCE_GROUP, FACTORY) - hook._conn.data_flows.create_or_update.assert_called_with(*sdk_args) + hook._conn.data_flows.create_or_update.assert_called_with(RESOURCE_GROUP, FACTORY, NAME, MODEL, None) -@parametrize( - explicit_factory=((NAME, MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME, MODEL)), - implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME, MODEL)), -) -def test_update_dataflow(hook: AzureDataFactoryHook, user_args, sdk_args): +def test_update_dataflow(hook: AzureDataFactoryHook): with patch.object(hook, "_dataflow_exists") as mock_dataflow_exists: mock_dataflow_exists.return_value = True - hook.update_dataflow(*user_args) + hook.update_dataflow(NAME, MODEL, RESOURCE_GROUP, FACTORY) - hook._conn.data_flows.create_or_update.assert_called_with(*sdk_args) + hook._conn.data_flows.create_or_update.assert_called_with(RESOURCE_GROUP, FACTORY, NAME, 
MODEL, None) -@parametrize( - explicit_factory=((NAME, MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME, MODEL)), - implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME, MODEL)), -) -def test_update_dataflow_non_existent(hook: AzureDataFactoryHook, user_args, sdk_args): +def test_update_dataflow_non_existent(hook: AzureDataFactoryHook): with patch.object(hook, "_dataflow_exists") as mock_dataflow_exists: mock_dataflow_exists.return_value = False with pytest.raises(AirflowException, match=r"Dataflow .+ does not exist"): - hook.update_dataflow(*user_args) + hook.update_dataflow(NAME, MODEL, RESOURCE_GROUP, FACTORY) -@parametrize( - explicit_factory=((NAME, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME)), - implicit_factory=( - (NAME,), - ( - DEFAULT_RESOURCE_GROUP, - DEFAULT_FACTORY, - NAME, - ), - ), -) -def test_delete_dataflow(hook: AzureDataFactoryHook, user_args, sdk_args): - hook.delete_dataflow(*user_args) +def test_delete_dataflow(hook: AzureDataFactoryHook): + hook.delete_dataflow(NAME, RESOURCE_GROUP, FACTORY) - hook._conn.data_flows.delete.assert_called_with(*sdk_args) + hook._conn.data_flows.delete.assert_called_with(RESOURCE_GROUP, FACTORY, NAME) -@parametrize( - explicit_factory=((NAME, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME)), - implicit_factory=((NAME,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME)), -) -def test_get_pipeline(hook: AzureDataFactoryHook, user_args, sdk_args): - hook.get_pipeline(*user_args) +def test_get_pipeline(hook: AzureDataFactoryHook): + hook.get_pipeline(NAME, RESOURCE_GROUP, FACTORY) - hook._conn.pipelines.get.assert_called_with(*sdk_args) + hook._conn.pipelines.get.assert_called_with(RESOURCE_GROUP, FACTORY, NAME) -@parametrize( - explicit_factory=((NAME, MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME, MODEL)), - implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME, MODEL)), -) -def test_create_pipeline(hook: AzureDataFactoryHook, user_args, sdk_args): - hook.create_pipeline(*user_args) +def test_create_pipeline(hook: AzureDataFactoryHook): + hook.create_pipeline(NAME, MODEL, RESOURCE_GROUP, FACTORY) - hook._conn.pipelines.create_or_update.assert_called_with(*sdk_args) + hook._conn.pipelines.create_or_update.assert_called_with(RESOURCE_GROUP, FACTORY, NAME, MODEL) -@parametrize( - explicit_factory=((NAME, MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME, MODEL)), - implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME, MODEL)), -) -def test_update_pipeline(hook: AzureDataFactoryHook, user_args, sdk_args): +def test_update_pipeline(hook: AzureDataFactoryHook): with patch.object(hook, "_pipeline_exists") as mock_pipeline_exists: mock_pipeline_exists.return_value = True - hook.update_pipeline(*user_args) + hook.update_pipeline(NAME, MODEL, RESOURCE_GROUP, FACTORY) - hook._conn.pipelines.create_or_update.assert_called_with(*sdk_args) + hook._conn.pipelines.create_or_update.assert_called_with(RESOURCE_GROUP, FACTORY, NAME, MODEL) -@parametrize( - explicit_factory=((NAME, MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME, MODEL)), - implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME, MODEL)), -) -def test_update_pipeline_non_existent(hook: AzureDataFactoryHook, user_args, sdk_args): +def test_update_pipeline_non_existent(hook: AzureDataFactoryHook): with patch.object(hook, "_pipeline_exists") as mock_pipeline_exists: mock_pipeline_exists.return_value = False 
with pytest.raises(AirflowException, match=r"Pipeline .+ does not exist"): - hook.update_pipeline(*user_args) + hook.update_pipeline(NAME, MODEL, RESOURCE_GROUP, FACTORY) -@parametrize( - explicit_factory=((NAME, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME)), - implicit_factory=((NAME,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME)), -) -def test_delete_pipeline(hook: AzureDataFactoryHook, user_args, sdk_args): - hook.delete_pipeline(*user_args) +def test_delete_pipeline(hook: AzureDataFactoryHook): + hook.delete_pipeline(NAME, RESOURCE_GROUP, FACTORY) - hook._conn.pipelines.delete.assert_called_with(*sdk_args) + hook._conn.pipelines.delete.assert_called_with(RESOURCE_GROUP, FACTORY, NAME) -@parametrize( - explicit_factory=((NAME, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME)), - implicit_factory=((NAME,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME)), -) -def test_run_pipeline(hook: AzureDataFactoryHook, user_args, sdk_args): - hook.run_pipeline(*user_args) +def test_run_pipeline(hook: AzureDataFactoryHook): + hook.run_pipeline(NAME, RESOURCE_GROUP, FACTORY) - hook._conn.pipelines.create_run.assert_called_with(*sdk_args) + hook._conn.pipelines.create_run.assert_called_with(RESOURCE_GROUP, FACTORY, NAME) -@parametrize( - explicit_factory=((ID, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, ID)), - implicit_factory=((ID,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, ID)), -) -def test_get_pipeline_run(hook: AzureDataFactoryHook, user_args, sdk_args): - hook.get_pipeline_run(*user_args) +def test_get_pipeline_run(hook: AzureDataFactoryHook): + hook.get_pipeline_run(ID, RESOURCE_GROUP, FACTORY) - hook._conn.pipeline_runs.get.assert_called_with(*sdk_args) + hook._conn.pipeline_runs.get.assert_called_with(RESOURCE_GROUP, FACTORY, ID) _wait_for_pipeline_run_status_test_args = [ @@ -504,7 +389,14 @@ def test_get_pipeline_run(hook: AzureDataFactoryHook, user_args, sdk_args): ], ) def test_wait_for_pipeline_run_status(hook, pipeline_run_status, expected_status, expected_output): - config = {"run_id": ID, "timeout": 3, "check_interval": 1, "expected_statuses": expected_status} + config = { + "resource_group_name": RESOURCE_GROUP, + "factory_name": FACTORY, + "run_id": ID, + "timeout": 3, + "check_interval": 1, + "expected_statuses": expected_status, + } with patch.object(AzureDataFactoryHook, "get_pipeline_run") as mock_pipeline_run: mock_pipeline_run.return_value.status = pipeline_run_status @@ -516,108 +408,68 @@ def test_wait_for_pipeline_run_status(hook, pipeline_run_status, expected_status hook.wait_for_pipeline_run_status(**config) -@parametrize( - explicit_factory=((ID, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, ID)), - implicit_factory=((ID,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, ID)), -) -def test_cancel_pipeline_run(hook: AzureDataFactoryHook, user_args, sdk_args): - hook.cancel_pipeline_run(*user_args) +def test_cancel_pipeline_run(hook: AzureDataFactoryHook): + hook.cancel_pipeline_run(ID, RESOURCE_GROUP, FACTORY) - hook._conn.pipeline_runs.cancel.assert_called_with(*sdk_args) + hook._conn.pipeline_runs.cancel.assert_called_with(RESOURCE_GROUP, FACTORY, ID) -@parametrize( - explicit_factory=((NAME, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME)), - implicit_factory=((NAME,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME)), -) -def test_get_trigger(hook: AzureDataFactoryHook, user_args, sdk_args): - hook.get_trigger(*user_args) +def test_get_trigger(hook: AzureDataFactoryHook): + hook.get_trigger(NAME, RESOURCE_GROUP, FACTORY) - 
hook._conn.triggers.get.assert_called_with(*sdk_args) + hook._conn.triggers.get.assert_called_with(RESOURCE_GROUP, FACTORY, NAME) -@parametrize( - explicit_factory=((NAME, MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME, MODEL)), - implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME, MODEL)), -) -def test_create_trigger(hook: AzureDataFactoryHook, user_args, sdk_args): - hook.create_trigger(*user_args) +def test_create_trigger(hook: AzureDataFactoryHook): + hook.create_trigger(NAME, MODEL, RESOURCE_GROUP, FACTORY) - hook._conn.triggers.create_or_update.assert_called_with(*sdk_args) + hook._conn.triggers.create_or_update.assert_called_with(RESOURCE_GROUP, FACTORY, NAME, MODEL) -@parametrize( - explicit_factory=((NAME, MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME, MODEL)), - implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME, MODEL)), -) -def test_update_trigger(hook: AzureDataFactoryHook, user_args, sdk_args): +def test_update_trigger(hook: AzureDataFactoryHook): with patch.object(hook, "_trigger_exists") as mock_trigger_exists: mock_trigger_exists.return_value = True - hook.update_trigger(*user_args) + hook.update_trigger(NAME, MODEL, RESOURCE_GROUP, FACTORY) - hook._conn.triggers.create_or_update.assert_called_with(*sdk_args) + hook._conn.triggers.create_or_update.assert_called_with(RESOURCE_GROUP, FACTORY, NAME, MODEL, None) -@parametrize( - explicit_factory=((NAME, MODEL, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME, MODEL)), - implicit_factory=((NAME, MODEL), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME, MODEL)), -) -def test_update_trigger_non_existent(hook: AzureDataFactoryHook, user_args, sdk_args): +def test_update_trigger_non_existent(hook: AzureDataFactoryHook): with patch.object(hook, "_trigger_exists") as mock_trigger_exists: mock_trigger_exists.return_value = False with pytest.raises(AirflowException, match=r"Trigger .+ does not exist"): - hook.update_trigger(*user_args) + hook.update_trigger(NAME, MODEL, RESOURCE_GROUP, FACTORY) -@parametrize( - explicit_factory=((NAME, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME)), - implicit_factory=((NAME,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME)), -) -def test_delete_trigger(hook: AzureDataFactoryHook, user_args, sdk_args): - hook.delete_trigger(*user_args) +def test_delete_trigger(hook: AzureDataFactoryHook): + hook.delete_trigger(NAME, RESOURCE_GROUP, FACTORY) - hook._conn.triggers.delete.assert_called_with(*sdk_args) + hook._conn.triggers.delete.assert_called_with(RESOURCE_GROUP, FACTORY, NAME) -@parametrize( - explicit_factory=((NAME, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME)), - implicit_factory=((NAME,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME)), -) -def test_start_trigger(hook: AzureDataFactoryHook, user_args, sdk_args): - hook.start_trigger(*user_args) +def test_start_trigger(hook: AzureDataFactoryHook): + hook.start_trigger(NAME, RESOURCE_GROUP, FACTORY) - hook._conn.triggers.begin_start.assert_called_with(*sdk_args) + hook._conn.triggers.begin_start.assert_called_with(RESOURCE_GROUP, FACTORY, NAME) -@parametrize( - explicit_factory=((NAME, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME)), - implicit_factory=((NAME,), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME)), -) -def test_stop_trigger(hook: AzureDataFactoryHook, user_args, sdk_args): - hook.stop_trigger(*user_args) +def test_stop_trigger(hook: AzureDataFactoryHook): + hook.stop_trigger(NAME, RESOURCE_GROUP, 
FACTORY) - hook._conn.triggers.begin_stop.assert_called_with(*sdk_args) + hook._conn.triggers.begin_stop.assert_called_with(RESOURCE_GROUP, FACTORY, NAME) -@parametrize( - explicit_factory=((NAME, ID, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME, ID)), - implicit_factory=((NAME, ID), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME, ID)), -) -def test_rerun_trigger(hook: AzureDataFactoryHook, user_args, sdk_args): - hook.rerun_trigger(*user_args) +def test_rerun_trigger(hook: AzureDataFactoryHook): + hook.rerun_trigger(NAME, ID, RESOURCE_GROUP, FACTORY) - hook._conn.trigger_runs.rerun.assert_called_with(*sdk_args) + hook._conn.trigger_runs.rerun.assert_called_with(RESOURCE_GROUP, FACTORY, NAME, ID) -@parametrize( - explicit_factory=((NAME, ID, RESOURCE_GROUP, FACTORY), (RESOURCE_GROUP, FACTORY, NAME, ID)), - implicit_factory=((NAME, ID), (DEFAULT_RESOURCE_GROUP, DEFAULT_FACTORY, NAME, ID)), -) -def test_cancel_trigger(hook: AzureDataFactoryHook, user_args, sdk_args): - hook.cancel_trigger(*user_args) +def test_cancel_trigger(hook: AzureDataFactoryHook): + hook.cancel_trigger(NAME, ID, RESOURCE_GROUP, FACTORY) - hook._conn.trigger_runs.cancel.assert_called_with(*sdk_args) + hook._conn.trigger_runs.cancel.assert_called_with(RESOURCE_GROUP, FACTORY, NAME, ID) @pytest.mark.parametrize( @@ -672,8 +524,8 @@ def test_connection_failure_missing_tenant_id(): def test_provide_targeted_factory_backcompat_prefix_works(mock_connect, uri): with patch.dict(os.environ, {"AIRFLOW_CONN_MY_CONN": uri}): hook = AzureDataFactoryHook("my_conn") - hook.delete_factory() - mock_connect.return_value.factories.delete.assert_called_with("abc", "abc") + hook.delete_factory(RESOURCE_GROUP, FACTORY) + mock_connect.return_value.factories.delete.assert_called_with(RESOURCE_GROUP, FACTORY) @pytest.mark.parametrize( @@ -707,8 +559,8 @@ def test_backcompat_prefix_both_prefers_short(mock_connect): }, ): hook = AzureDataFactoryHook("my_conn") - hook.delete_factory(factory_name="n/a") - mock_connect.return_value.factories.delete.assert_called_with("non-prefixed", "n/a") + hook.delete_factory(RESOURCE_GROUP, FACTORY) + mock_connect.return_value.factories.delete.assert_called_with(RESOURCE_GROUP, FACTORY) def test_refresh_conn(hook): diff --git a/tests/providers/microsoft/azure/operators/test_azure_data_factory.py b/tests/providers/microsoft/azure/operators/test_azure_data_factory.py index 98ded34e1c91f..599a7b54e947b 100644 --- a/tests/providers/microsoft/azure/operators/test_azure_data_factory.py +++ b/tests/providers/microsoft/azure/operators/test_azure_data_factory.py @@ -153,9 +153,9 @@ def test_execute_wait_for_termination(self, mock_run_pipeline, pipeline_run_stat ) mock_run_pipeline.assert_called_once_with( - pipeline_name=self.config["pipeline_name"], - resource_group_name=self.config["resource_group_name"], - factory_name=self.config["factory_name"], + self.config["pipeline_name"], + self.config["resource_group_name"], + self.config["factory_name"], reference_pipeline_run_id=None, is_recovery=None, start_activity_name=None, @@ -165,9 +165,9 @@ def test_execute_wait_for_termination(self, mock_run_pipeline, pipeline_run_stat if pipeline_run_status in AzureDataFactoryPipelineRunStatus.TERMINAL_STATUSES: mock_get_pipeline_run.assert_called_once_with( - run_id=mock_run_pipeline.return_value.run_id, - factory_name=self.config["factory_name"], - resource_group_name=self.config["resource_group_name"], + mock_run_pipeline.return_value.run_id, + self.config["resource_group_name"], + self.config["factory_name"], ) else: 
# When the pipeline run status is not in a terminal status or "Succeeded", the operator will @@ -177,9 +177,9 @@ def test_execute_wait_for_termination(self, mock_run_pipeline, pipeline_run_stat assert mock_get_pipeline_run.call_count == 4 mock_get_pipeline_run.assert_called_with( - run_id=mock_run_pipeline.return_value.run_id, - factory_name=self.config["factory_name"], - resource_group_name=self.config["resource_group_name"], + mock_run_pipeline.return_value.run_id, + self.config["resource_group_name"], + self.config["factory_name"], ) @patch.object(AzureDataFactoryHook, "run_pipeline", return_value=MagicMock(**PIPELINE_RUN_RESPONSE)) @@ -205,9 +205,9 @@ def test_execute_no_wait_for_termination(self, mock_run_pipeline): ) mock_run_pipeline.assert_called_once_with( - pipeline_name=self.config["pipeline_name"], - resource_group_name=self.config["resource_group_name"], - factory_name=self.config["factory_name"], + self.config["pipeline_name"], + self.config["resource_group_name"], + self.config["factory_name"], reference_pipeline_run_id=None, is_recovery=None, start_activity_name=None, @@ -268,7 +268,12 @@ def test_run_pipeline_operator_link(self, resource_group, factory, create_task_i class TestAzureDataFactoryRunPipelineOperatorWithDeferrable: OPERATOR = AzureDataFactoryRunPipelineOperator( - task_id="run_pipeline", pipeline_name="pipeline", parameters={"myParam": "value"}, deferrable=True + task_id="run_pipeline", + pipeline_name="pipeline", + resource_group_name="resource-group-name", + factory_name="factory-name", + parameters={"myParam": "value"}, + deferrable=True, ) def get_dag_run(self, dag_id: str = "test_dag_id", run_id: str = "test_dag_id") -> DagRun: diff --git a/tests/providers/microsoft/azure/sensors/test_azure_data_factory.py b/tests/providers/microsoft/azure/sensors/test_azure_data_factory.py index fb489c7ad7e6b..78631f036891b 100644 --- a/tests/providers/microsoft/azure/sensors/test_azure_data_factory.py +++ b/tests/providers/microsoft/azure/sensors/test_azure_data_factory.py @@ -125,7 +125,11 @@ def test_adf_pipeline_status_sensor_execute_complete_failure(self, soft_fail, ex class TestAzureDataFactoryPipelineRunStatusSensorWithAsync: RUN_ID = "7f8c6c72-c093-11ec-a83d-0242ac120007" SENSOR = AzureDataFactoryPipelineRunStatusSensor( - task_id="pipeline_run_sensor_async", run_id=RUN_ID, deferrable=True + task_id="pipeline_run_sensor_async", + run_id=RUN_ID, + resource_group_name="resource-group-name", + factory_name="factory-name", + deferrable=True, ) @mock.patch("airflow.providers.microsoft.azure.sensors.data_factory.AzureDataFactoryHook")
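To make the 8.0.0 breaking change concrete, a minimal migration sketch for direct hook usage follows; the pipeline, resource group, and factory names are placeholders, not values taken from this diff:

```python
# resource_group_name and factory_name are now required parameters on the hook
# methods (passed positionally after the resource name) rather than optional
# keyword arguments resolved lazily.
from airflow.providers.microsoft.azure.hooks.data_factory import AzureDataFactoryHook

hook = AzureDataFactoryHook(azure_data_factory_conn_id="azure_data_factory_default")

# 7.x allowed hook.run_pipeline("my_pipeline") with the resource group and
# factory pulled from connection extras; 8.0.0 expects both explicitly:
run = hook.run_pipeline("my_pipeline", "my-resource-group", "my-data-factory")

# Run-level calls take the same two identifiers after the run_id:
status = hook.get_pipeline_run_status(run.run_id, "my-resource-group", "my-data-factory")
```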
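The operator and sensor signatures change the same way; below is a sketch of DAG-level usage under this release. All names are placeholders (the run_id literal is the one used in the sensor tests above), and in a real DAG the run_id would typically be pulled from the operator's XCom:

```python
# resource_group_name and factory_name are keyword-only and now required on
# both the run-pipeline operator and the run-status sensor.
from airflow.providers.microsoft.azure.operators.data_factory import (
    AzureDataFactoryRunPipelineOperator,
)
from airflow.providers.microsoft.azure.sensors.data_factory import (
    AzureDataFactoryPipelineRunStatusSensor,
)

run_pipeline = AzureDataFactoryRunPipelineOperator(
    task_id="run_pipeline",
    pipeline_name="my_pipeline",
    resource_group_name="my-resource-group",
    factory_name="my-data-factory",
    wait_for_termination=True,
)

wait_for_run = AzureDataFactoryPipelineRunStatusSensor(
    task_id="wait_for_run",
    # Literal run_id keeps the sketch self-contained.
    run_id="7f8c6c72-c093-11ec-a83d-0242ac120007",
    resource_group_name="my-resource-group",
    factory_name="my-data-factory",
)
```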