Skip to content

Commit

Permalink
Deprecate WasbPrefixSensorAsync
Browse files Browse the repository at this point in the history
Deprecate WasbPrefixSensorAsync and proxy it to its Airflow OSS
provider's counterpart

related: #1412
  • Loading branch information
pankajkoti committed Jan 19, 2024
1 parent 52aaa36 commit f7c4fed
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 114 deletions.
76 changes: 19 additions & 57 deletions astronomer/providers/microsoft/azure/sensors/wasb.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,11 @@
import warnings
from datetime import timedelta
from typing import Any, Dict, List, Optional
from typing import Any, List, Optional

from airflow.providers.microsoft.azure.sensors.wasb import (
WasbBlobSensor,
WasbPrefixSensor,
)

from astronomer.providers.microsoft.azure.triggers.wasb import WasbPrefixSensorTrigger
from astronomer.providers.utils.sensor_util import poke, raise_error_or_skip_exception
from astronomer.providers.utils.typing_compat import Context


class WasbBlobSensorAsync(WasbBlobSensor):
"""
Expand Down Expand Up @@ -48,31 +43,28 @@ def __init__(

class WasbPrefixSensorAsync(WasbPrefixSensor):
"""
Polls asynchronously for the existence of a blob having the given prefix in a WASB container.
:param container_name: name of the container in which the blob should be searched for
:param blob_name: name of the blob to check existence for
:param include: specifies one or more additional datasets to include in the
response. Options include: ``snapshots``, ``metadata``, ``uncommittedblobs``,
``copy`, ``deleted``
:param delimiter: filters objects based on the delimiter (for e.g '.csv')
:param wasb_conn_id: the connection identifier for connecting to Azure WASB
:param poll_interval: polling period in seconds to check for the status
:param public_read: whether an anonymous public read access should be used. Default is False
This class is deprecated.
Use :class: `~airflow.providers.microsoft.azure.sensors.wasb.WasbPrefixSensor` instead
and set `deferrable` param to `True` instead.
"""

def __init__(
self,
*,
container_name: str,
prefix: str,
*args,
include: Optional[List[str]] = None,
delimiter: Optional[str] = "/",
wasb_conn_id: str = "wasb_default",
public_read: bool = False,
poll_interval: float = 5.0,
**kwargs: Any,
):
warnings.warn(
(
"This class is deprecated. "
"Use `airflow.providers.microsoft.azure.sensors.wasb.WasbPrefixSensor` "
"and set `deferrable` param to `True` instead."
),
DeprecationWarning,
stacklevel=2,
)
# TODO: Remove once deprecated
if poll_interval:
self.poke_interval = poll_interval
Expand All @@ -82,38 +74,8 @@ def __init__(
DeprecationWarning,
stacklevel=2,
)
super().__init__(container_name=container_name, prefix=prefix, **kwargs)
self.container_name = container_name
self.prefix = prefix
self.include = include
self.delimiter = delimiter
self.wasb_conn_id = wasb_conn_id
self.public_read = public_read

def execute(self, context: Context) -> None:
"""Defers trigger class to poll for state of the job run until it reaches a failure state or success state"""
if not poke(self, context):
self.defer(
timeout=timedelta(seconds=self.timeout),
trigger=WasbPrefixSensorTrigger(
container_name=self.container_name,
prefix=self.prefix,
include=self.include,
delimiter=self.delimiter,
wasb_conn_id=self.wasb_conn_id,
public_read=self.public_read,
poke_interval=self.poke_interval,
),
method_name="execute_complete",
)

def execute_complete(self, context: Context, event: Dict[str, str]) -> None:
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
successful.
"""
if event:
if event["status"] == "error":
raise_error_or_skip_exception(self.soft_fail, event["message"])
self.log.info(event["message"])
if kwargs.get("check_options") is None:
kwargs["check_options"] = {}
kwargs["check_options"]["include"] = include
kwargs["check_options"]["delimiter"] = delimiter
super().__init__(*args, deferrable=True, **kwargs)
22 changes: 10 additions & 12 deletions astronomer/providers/microsoft/azure/triggers/wasb.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,18 +76,8 @@ async def run(self) -> AsyncIterator["TriggerEvent"]:

class WasbPrefixSensorTrigger(BaseTrigger):
"""
WasbPrefixSensorTrigger is fired as a deferred class with params to run the task in trigger worker.
It checks for the existence of a blob with the given prefix in the provided container.
:param container_name: name of the container in which the blob should be searched for
:param prefix: prefix of the blob to check existence for
:param include: specifies one or more additional datasets to include in the
response. Options include: ``snapshots``, ``metadata``, ``uncommittedblobs``,
``copy`, ``deleted``
:param delimiter: filters objects based on the delimiter (for e.g '.csv')
:param wasb_conn_id: the connection identifier for connecting to Azure WASB
:param poke_interval: polling period in seconds to check for the status
:param public_read: whether an anonymous public read access should be used. Default is False
This class is deprecated and will be removed in 2.0.0.
Use :class: `~airflow.providers.microsoft.azure.triggers.wasb.WasbPrefixSensorTrigger` instead.
"""

def __init__(
Expand All @@ -100,6 +90,14 @@ def __init__(
public_read: bool = False,
poke_interval: float = 5.0,
):
warnings.warn(
(
"This class is deprecated and will be removed in 2.0.0."
"Use :class: `~airflow.providers.microsoft.azure.triggers.wasb.WasbPrefixSensorTrigger` instead"
),
DeprecationWarning,
stacklevel=2,
)
super().__init__()
self.container_name = container_name
self.prefix = prefix
Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ all =
apache-airflow-providers-http
apache-airflow-providers-snowflake
apache-airflow-providers-sftp
# TODO: Increment microsoft-azure version as per the upcoming release for deprecating WasbPrefixSensorAsync
apache-airflow-providers-microsoft-azure>=8.5.1
asyncssh>=2.12.0
databricks-sql-connector>=2.0.4;python_version>='3.10'
Expand Down
65 changes: 20 additions & 45 deletions tests/microsoft/azure/sensors/test_wasb.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,10 @@
from unittest import mock

import pytest
from airflow.exceptions import AirflowException, TaskDeferred
from airflow.providers.microsoft.azure.sensors.wasb import WasbBlobSensor
from airflow.providers.microsoft.azure.sensors.wasb import WasbBlobSensor, WasbPrefixSensor

from astronomer.providers.microsoft.azure.sensors.wasb import (
WasbBlobSensorAsync,
WasbPrefixSensorAsync,
)
from astronomer.providers.microsoft.azure.triggers.wasb import WasbPrefixSensorTrigger
from tests.utils.airflow_util import create_context

TEST_DATA_STORAGE_BLOB_NAME = "test_blob_providers_team.txt"
TEST_DATA_STORAGE_CONTAINER_NAME = "test-container-providers-team"
Expand All @@ -29,48 +24,28 @@ def test_init(self):
assert isinstance(task, WasbBlobSensor)
assert task.deferrable is True

def test_poll_interval_deprecation_warning_wasb_blob(self):
"""Test DeprecationWarning for WasbBlobSensorAsync by setting param poll_interval"""
# TODO: Remove once deprecated
with pytest.warns(expected_warning=DeprecationWarning):
WasbBlobSensorAsync(
task_id="wasb_blob_sensor_async",
container_name=TEST_DATA_STORAGE_CONTAINER_NAME,
blob_name=TEST_DATA_STORAGE_BLOB_NAME,
poll_interval=5.0,
)

class TestWasbPrefixSensorAsync:
SENSOR = WasbPrefixSensorAsync(
task_id="wasb_prefix_sensor_async",
container_name=TEST_DATA_STORAGE_CONTAINER_NAME,
prefix=TEST_DATA_STORAGE_BLOB_PREFIX,
)

@mock.patch(f"{MODULE}.WasbPrefixSensorAsync.defer")
@mock.patch(f"{MODULE}.WasbPrefixSensorAsync.poke", return_value=True)
def test_wasb_prefix_sensor_async_finish_before_deferred(self, mock_poke, mock_defer, context):
"""Assert task is not deferred when it receives a finish status before deferring"""
self.SENSOR.execute(create_context(self.SENSOR))

assert not mock_defer.called

@mock.patch(f"{MODULE}.WasbPrefixSensorAsync.poke", return_value=False)
def test_wasb_prefix_sensor_async(self, mock_poke):
"""Assert execute method defer for wasb prefix sensor"""

with pytest.raises(TaskDeferred) as exc:
self.SENSOR.execute(create_context(self.SENSOR))
assert isinstance(
exc.value.trigger, WasbPrefixSensorTrigger
), "Trigger is not a WasbPrefixSensorTrigger"

@pytest.mark.parametrize(
"event",
[{"status": "success", "message": "Job completed"}],
)
def test_wasb_prefix_sensor_execute_complete_success(self, event):
"""Assert execute_complete log success message when trigger fire with target status."""

with mock.patch.object(self.SENSOR.log, "info") as mock_log_info:
self.SENSOR.execute_complete(context={}, event=event)
mock_log_info.assert_called_with(event["message"])

def test_wasb_prefix_sensor_execute_complete_failure(self):
"""Assert execute_complete method raises an exception when the triggerer fires an error event."""
class TestWasbPrefixSensorAsync:
def test_init(self):
task = WasbPrefixSensorAsync(
task_id="wasb_prefix_sensor_async",
container_name=TEST_DATA_STORAGE_CONTAINER_NAME,
prefix=TEST_DATA_STORAGE_BLOB_PREFIX,
)

with pytest.raises(AirflowException):
self.SENSOR.execute_complete(context={}, event={"status": "error", "message": ""})
assert isinstance(task, WasbPrefixSensor)
assert task.deferrable is True

def test_poll_interval_deprecation_warning(self):
"""Test DeprecationWarning for WasbPrefixSensorAsync by setting param poll_interval"""
Expand Down

0 comments on commit f7c4fed

Please sign in to comment.