diff --git a/astronomer/providers/microsoft/azure/sensors/wasb.py b/astronomer/providers/microsoft/azure/sensors/wasb.py index bd56f45d2..3d6e04f70 100644 --- a/astronomer/providers/microsoft/azure/sensors/wasb.py +++ b/astronomer/providers/microsoft/azure/sensors/wasb.py @@ -1,16 +1,11 @@ import warnings -from datetime import timedelta -from typing import Any, Dict, List, Optional +from typing import Any, List, Optional from airflow.providers.microsoft.azure.sensors.wasb import ( WasbBlobSensor, WasbPrefixSensor, ) -from astronomer.providers.microsoft.azure.triggers.wasb import WasbPrefixSensorTrigger -from astronomer.providers.utils.sensor_util import poke, raise_error_or_skip_exception -from astronomer.providers.utils.typing_compat import Context - class WasbBlobSensorAsync(WasbBlobSensor): """ @@ -48,31 +43,28 @@ def __init__( class WasbPrefixSensorAsync(WasbPrefixSensor): """ - Polls asynchronously for the existence of a blob having the given prefix in a WASB container. - - :param container_name: name of the container in which the blob should be searched for - :param blob_name: name of the blob to check existence for - :param include: specifies one or more additional datasets to include in the - response. Options include: ``snapshots``, ``metadata``, ``uncommittedblobs``, - ``copy`, ``deleted`` - :param delimiter: filters objects based on the delimiter (for e.g '.csv') - :param wasb_conn_id: the connection identifier for connecting to Azure WASB - :param poll_interval: polling period in seconds to check for the status - :param public_read: whether an anonymous public read access should be used. Default is False + This class is deprecated. + Use :class: `~airflow.providers.microsoft.azure.sensors.wasb.WasbPrefixSensor` instead + and set `deferrable` param to `True` instead. """ def __init__( self, - *, - container_name: str, - prefix: str, + *args, include: Optional[List[str]] = None, delimiter: Optional[str] = "/", - wasb_conn_id: str = "wasb_default", - public_read: bool = False, poll_interval: float = 5.0, **kwargs: Any, ): + warnings.warn( + ( + "This class is deprecated. " + "Use `airflow.providers.microsoft.azure.sensors.wasb.WasbPrefixSensor` " + "and set `deferrable` param to `True` instead." + ), + DeprecationWarning, + stacklevel=2, + ) # TODO: Remove once deprecated if poll_interval: self.poke_interval = poll_interval @@ -82,38 +74,8 @@ def __init__( DeprecationWarning, stacklevel=2, ) - super().__init__(container_name=container_name, prefix=prefix, **kwargs) - self.container_name = container_name - self.prefix = prefix - self.include = include - self.delimiter = delimiter - self.wasb_conn_id = wasb_conn_id - self.public_read = public_read - - def execute(self, context: Context) -> None: - """Defers trigger class to poll for state of the job run until it reaches a failure state or success state""" - if not poke(self, context): - self.defer( - timeout=timedelta(seconds=self.timeout), - trigger=WasbPrefixSensorTrigger( - container_name=self.container_name, - prefix=self.prefix, - include=self.include, - delimiter=self.delimiter, - wasb_conn_id=self.wasb_conn_id, - public_read=self.public_read, - poke_interval=self.poke_interval, - ), - method_name="execute_complete", - ) - - def execute_complete(self, context: Context, event: Dict[str, str]) -> None: - """ - Callback for when the trigger fires - returns immediately. - Relies on trigger to throw an exception, otherwise it assumes execution was - successful. - """ - if event: - if event["status"] == "error": - raise_error_or_skip_exception(self.soft_fail, event["message"]) - self.log.info(event["message"]) + if kwargs.get("check_options") is None: + kwargs["check_options"] = {} + kwargs["check_options"]["include"] = include + kwargs["check_options"]["delimiter"] = delimiter + super().__init__(*args, deferrable=True, **kwargs) diff --git a/astronomer/providers/microsoft/azure/triggers/wasb.py b/astronomer/providers/microsoft/azure/triggers/wasb.py index fa2160eba..4e90c4f30 100644 --- a/astronomer/providers/microsoft/azure/triggers/wasb.py +++ b/astronomer/providers/microsoft/azure/triggers/wasb.py @@ -76,18 +76,8 @@ async def run(self) -> AsyncIterator["TriggerEvent"]: class WasbPrefixSensorTrigger(BaseTrigger): """ - WasbPrefixSensorTrigger is fired as a deferred class with params to run the task in trigger worker. - It checks for the existence of a blob with the given prefix in the provided container. - - :param container_name: name of the container in which the blob should be searched for - :param prefix: prefix of the blob to check existence for - :param include: specifies one or more additional datasets to include in the - response. Options include: ``snapshots``, ``metadata``, ``uncommittedblobs``, - ``copy`, ``deleted`` - :param delimiter: filters objects based on the delimiter (for e.g '.csv') - :param wasb_conn_id: the connection identifier for connecting to Azure WASB - :param poke_interval: polling period in seconds to check for the status - :param public_read: whether an anonymous public read access should be used. Default is False + This class is deprecated and will be removed in 2.0.0. + Use :class: `~airflow.providers.microsoft.azure.triggers.wasb.WasbPrefixSensorTrigger` instead. """ def __init__( @@ -100,6 +90,14 @@ def __init__( public_read: bool = False, poke_interval: float = 5.0, ): + warnings.warn( + ( + "This class is deprecated and will be removed in 2.0.0." + "Use :class: `~airflow.providers.microsoft.azure.triggers.wasb.WasbPrefixSensorTrigger` instead" + ), + DeprecationWarning, + stacklevel=2, + ) super().__init__() self.container_name = container_name self.prefix = prefix diff --git a/setup.cfg b/setup.cfg index 868385db7..6d28d83dc 100644 --- a/setup.cfg +++ b/setup.cfg @@ -127,6 +127,7 @@ all = apache-airflow-providers-http apache-airflow-providers-snowflake apache-airflow-providers-sftp + # TODO: Increment microsoft-azure version as per the upcoming release for deprecating WasbPrefixSensorAsync apache-airflow-providers-microsoft-azure>=8.5.1 asyncssh>=2.12.0 databricks-sql-connector>=2.0.4;python_version>='3.10' diff --git a/tests/microsoft/azure/sensors/test_wasb.py b/tests/microsoft/azure/sensors/test_wasb.py index 3b59ba043..45ec8b5c6 100644 --- a/tests/microsoft/azure/sensors/test_wasb.py +++ b/tests/microsoft/azure/sensors/test_wasb.py @@ -1,15 +1,10 @@ -from unittest import mock - import pytest -from airflow.exceptions import AirflowException, TaskDeferred -from airflow.providers.microsoft.azure.sensors.wasb import WasbBlobSensor +from airflow.providers.microsoft.azure.sensors.wasb import WasbBlobSensor, WasbPrefixSensor from astronomer.providers.microsoft.azure.sensors.wasb import ( WasbBlobSensorAsync, WasbPrefixSensorAsync, ) -from astronomer.providers.microsoft.azure.triggers.wasb import WasbPrefixSensorTrigger -from tests.utils.airflow_util import create_context TEST_DATA_STORAGE_BLOB_NAME = "test_blob_providers_team.txt" TEST_DATA_STORAGE_CONTAINER_NAME = "test-container-providers-team" @@ -29,48 +24,28 @@ def test_init(self): assert isinstance(task, WasbBlobSensor) assert task.deferrable is True + def test_poll_interval_deprecation_warning_wasb_blob(self): + """Test DeprecationWarning for WasbBlobSensorAsync by setting param poll_interval""" + # TODO: Remove once deprecated + with pytest.warns(expected_warning=DeprecationWarning): + WasbBlobSensorAsync( + task_id="wasb_blob_sensor_async", + container_name=TEST_DATA_STORAGE_CONTAINER_NAME, + blob_name=TEST_DATA_STORAGE_BLOB_NAME, + poll_interval=5.0, + ) -class TestWasbPrefixSensorAsync: - SENSOR = WasbPrefixSensorAsync( - task_id="wasb_prefix_sensor_async", - container_name=TEST_DATA_STORAGE_CONTAINER_NAME, - prefix=TEST_DATA_STORAGE_BLOB_PREFIX, - ) - - @mock.patch(f"{MODULE}.WasbPrefixSensorAsync.defer") - @mock.patch(f"{MODULE}.WasbPrefixSensorAsync.poke", return_value=True) - def test_wasb_prefix_sensor_async_finish_before_deferred(self, mock_poke, mock_defer, context): - """Assert task is not deferred when it receives a finish status before deferring""" - self.SENSOR.execute(create_context(self.SENSOR)) - - assert not mock_defer.called - - @mock.patch(f"{MODULE}.WasbPrefixSensorAsync.poke", return_value=False) - def test_wasb_prefix_sensor_async(self, mock_poke): - """Assert execute method defer for wasb prefix sensor""" - - with pytest.raises(TaskDeferred) as exc: - self.SENSOR.execute(create_context(self.SENSOR)) - assert isinstance( - exc.value.trigger, WasbPrefixSensorTrigger - ), "Trigger is not a WasbPrefixSensorTrigger" - - @pytest.mark.parametrize( - "event", - [{"status": "success", "message": "Job completed"}], - ) - def test_wasb_prefix_sensor_execute_complete_success(self, event): - """Assert execute_complete log success message when trigger fire with target status.""" - - with mock.patch.object(self.SENSOR.log, "info") as mock_log_info: - self.SENSOR.execute_complete(context={}, event=event) - mock_log_info.assert_called_with(event["message"]) - def test_wasb_prefix_sensor_execute_complete_failure(self): - """Assert execute_complete method raises an exception when the triggerer fires an error event.""" +class TestWasbPrefixSensorAsync: + def test_init(self): + task = WasbPrefixSensorAsync( + task_id="wasb_prefix_sensor_async", + container_name=TEST_DATA_STORAGE_CONTAINER_NAME, + prefix=TEST_DATA_STORAGE_BLOB_PREFIX, + ) - with pytest.raises(AirflowException): - self.SENSOR.execute_complete(context={}, event={"status": "error", "message": ""}) + assert isinstance(task, WasbPrefixSensor) + assert task.deferrable is True def test_poll_interval_deprecation_warning(self): """Test DeprecationWarning for WasbPrefixSensorAsync by setting param poll_interval"""