From 1c3fb4f29c6739945ed0484b71ce0a12494fde4d Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Fri, 22 Dec 2023 11:09:59 +0800 Subject: [PATCH] Revert "Deprecate S3KeysUnchangedSensorAsync (#1392)" This reverts commit d6cbd8ea5fa0b660e6f734c32914df4950ac02e6. --- astronomer/providers/amazon/aws/sensors/s3.py | 80 +++++++++++--- .../providers/amazon/aws/triggers/s3.py | 101 ++++++++++++++++++ tests/amazon/aws/sensors/test_s3_sensors.py | 78 +++++++++++++- tests/amazon/aws/triggers/test_s3_triggers.py | 78 ++++++++++++++ 4 files changed, 318 insertions(+), 19 deletions(-) diff --git a/astronomer/providers/amazon/aws/sensors/s3.py b/astronomer/providers/amazon/aws/sensors/s3.py index a4e902060..94ec86cc8 100644 --- a/astronomer/providers/amazon/aws/sensors/s3.py +++ b/astronomer/providers/amazon/aws/sensors/s3.py @@ -9,9 +9,10 @@ from airflow.sensors.base import BaseSensorOperator from astronomer.providers.amazon.aws.triggers.s3 import ( + S3KeysUnchangedTrigger, S3KeyTrigger, ) -from astronomer.providers.utils.sensor_util import raise_error_or_skip_exception +from astronomer.providers.utils.sensor_util import poke, raise_error_or_skip_exception from astronomer.providers.utils.typing_compat import Context @@ -154,21 +155,72 @@ def __init__( class S3KeysUnchangedSensorAsync(S3KeysUnchangedSensor): """ - This class is deprecated. - Please use :class: `~airflow.providers.amazon.aws.sensors.s3.S3KeysUnchangedSensor`. + Checks for changes in the number of objects at prefix in AWS S3 + bucket and returns True if the inactivity period has passed with no + increase in the number of objects. Note, this sensor will not behave correctly + in reschedule mode, as the state of the listed objects in the S3 bucket will + be lost between rescheduled invocations. + + :param bucket_name: Name of the S3 bucket + :param prefix: The prefix being waited on. Relative path from bucket root level. + :param aws_conn_id: a reference to the s3 connection + :param verify: Whether or not to verify SSL certificates for S3 connection. + By default SSL certificates are verified. + You can provide the following values: + + - ``False``: do not validate SSL certificates. SSL will still be used + (unless use_ssl is False), but SSL certificates will not be + verified. + - ``path/to/cert/bundle.pem``: A filename of the CA cert bundle to uses. + You can specify this argument if you want to use a different + CA cert bundle than the one used by botocore. + :param inactivity_period: The total seconds of inactivity to designate + keys unchanged. Note, this mechanism is not real time and + this operator may not return until a poke_interval after this period + has passed with no additional objects sensed. + :param min_objects: The minimum number of objects needed for keys unchanged + sensor to be considered valid. + :param previous_objects: The set of object ids found during the last poke. + :param allow_delete: Should this sensor consider objects being deleted + between pokes valid behavior. If true a warning message will be logged + when this happens. If false an error will be raised. """ - def __init__(self, *args, **kwargs): # type: ignore[no-untyped-def] - warnings.warn( - ( - "This module is deprecated. " - "Please use `airflow.providers.amazon.aws.sensors.s3.S3KeysUnchangedSensor` " - "and set deferrable to True instead." - ), - DeprecationWarning, - stacklevel=2, - ) - return super().__init__(*args, deferrable=True, **kwargs) + def __init__( + self, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + + def execute(self, context: Context) -> None: + """Defers Trigger class to check for changes in the number of objects at prefix in AWS S3""" + if not poke(self, context): + self.defer( + timeout=timedelta(seconds=self.timeout), + trigger=S3KeysUnchangedTrigger( + bucket_name=self.bucket_name, + prefix=self.prefix, + inactivity_period=self.inactivity_period, + min_objects=self.min_objects, + previous_objects=self.previous_objects, + inactivity_seconds=self.inactivity_seconds, + allow_delete=self.allow_delete, + aws_conn_id=self.aws_conn_id, + verify=self.verify, + last_activity_time=self.last_activity_time, + ), + method_name="execute_complete", + ) + + def execute_complete(self, context: Context, event: Any = None) -> None: + """ + Callback for when the trigger fires - returns immediately. + Relies on trigger to throw an exception, otherwise it assumes execution was + successful. + """ + if event["status"] == "error": + raise_error_or_skip_exception(self.soft_fail, event["message"]) + return None class S3PrefixSensorAsync(BaseSensorOperator): diff --git a/astronomer/providers/amazon/aws/triggers/s3.py b/astronomer/providers/amazon/aws/triggers/s3.py index 20f8157c0..30a1957a8 100644 --- a/astronomer/providers/amazon/aws/triggers/s3.py +++ b/astronomer/providers/amazon/aws/triggers/s3.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +from datetime import datetime from typing import Any, AsyncIterator from airflow.triggers.base import BaseTrigger, TriggerEvent @@ -90,3 +91,103 @@ async def run(self) -> AsyncIterator[TriggerEvent]: def _get_async_hook(self) -> S3HookAsync: return S3HookAsync(aws_conn_id=self.aws_conn_id, verify=self.hook_params.get("verify")) + + +class S3KeysUnchangedTrigger(BaseTrigger): + """ + S3KeyTrigger is fired as deferred class with params to run the task in trigger worker + + :param bucket_name: Name of the S3 bucket. Only needed when ``bucket_key`` + is not provided as a full s3:// url. + :param prefix: The prefix being waited on. Relative path from bucket root level. + :param inactivity_period: The total seconds of inactivity to designate + keys unchanged. Note, this mechanism is not real time and + this operator may not return until a poke_interval after this period + has passed with no additional objects sensed. + :param min_objects: The minimum number of objects needed for keys unchanged + sensor to be considered valid. + :param inactivity_seconds: reference to the seconds of inactivity + :param previous_objects: The set of object ids found during the last poke. + :param allow_delete: Should this sensor consider objects being deleted + :param aws_conn_id: reference to the s3 connection + :param last_activity_time: last modified or last active time + :param verify: Whether or not to verify SSL certificates for S3 connection. + By default SSL certificates are verified. + """ + + def __init__( + self, + bucket_name: str, + prefix: str, + inactivity_period: float = 60 * 60, + min_objects: int = 1, + inactivity_seconds: int = 0, + previous_objects: set[str] | None = None, + allow_delete: bool = True, + aws_conn_id: str = "aws_default", + last_activity_time: datetime | None = None, + verify: bool | str | None = None, + ): + super().__init__() + self.bucket_name = bucket_name + self.prefix = prefix + if inactivity_period < 0: + raise ValueError("inactivity_period must be non-negative") + if previous_objects is None: + previous_objects = set() + self.inactivity_period = inactivity_period + self.min_objects = min_objects + self.previous_objects = previous_objects + self.inactivity_seconds = inactivity_seconds + self.allow_delete = allow_delete + self.aws_conn_id = aws_conn_id + self.last_activity_time: datetime | None = last_activity_time + self.verify = verify + self.polling_period_seconds = 0 + + def serialize(self) -> tuple[str, dict[str, Any]]: + """Serialize S3KeysUnchangedTrigger arguments and classpath.""" + return ( + "astronomer.providers.amazon.aws.triggers.s3.S3KeysUnchangedTrigger", + { + "bucket_name": self.bucket_name, + "prefix": self.prefix, + "inactivity_period": self.inactivity_period, + "min_objects": self.min_objects, + "previous_objects": self.previous_objects, + "inactivity_seconds": self.inactivity_seconds, + "allow_delete": self.allow_delete, + "aws_conn_id": self.aws_conn_id, + "last_activity_time": self.last_activity_time, + }, + ) + + async def run(self) -> AsyncIterator[TriggerEvent]: + """Make an asynchronous connection using S3HookAsync.""" + try: + hook = self._get_async_hook() + async with await hook.get_client_async() as client: + while True: + result = await hook.is_keys_unchanged( + client, + self.bucket_name, + self.prefix, + self.inactivity_period, + self.min_objects, + self.previous_objects, + self.inactivity_seconds, + self.allow_delete, + self.last_activity_time, + ) + if result.get("status") == "success" or result.get("status") == "error": + yield TriggerEvent(result) + elif result.get("status") == "pending": + self.previous_objects = result.get("previous_objects", set()) + self.last_activity_time = result.get("last_activity_time") + self.inactivity_seconds = result.get("inactivity_seconds", 0) + await asyncio.sleep(self.polling_period_seconds) + except Exception as e: + yield TriggerEvent({"status": "error", "message": str(e)}) + + def _get_async_hook(self) -> S3HookAsync: + return S3HookAsync(aws_conn_id=self.aws_conn_id, verify=self.verify) diff --git a/tests/amazon/aws/sensors/test_s3_sensors.py b/tests/amazon/aws/sensors/test_s3_sensors.py index 4dfae42d6..8ba3047ad 100644 --- a/tests/amazon/aws/sensors/test_s3_sensors.py +++ b/tests/amazon/aws/sensors/test_s3_sensors.py @@ -7,7 +7,6 @@ from airflow.exceptions import AirflowException, AirflowSkipException, TaskDeferred from airflow.models import DAG, DagRun, TaskInstance from airflow.models.variable import Variable -from airflow.providers.amazon.aws.sensors.s3 import S3KeysUnchangedSensor from airflow.utils import timezone from parameterized import parameterized @@ -18,6 +17,7 @@ S3PrefixSensorAsync, ) from astronomer.providers.amazon.aws.triggers.s3 import ( + S3KeysUnchangedTrigger, S3KeyTrigger, ) @@ -293,12 +293,80 @@ def test_soft_fail_enable(self, context): class TestS3KeysUnchangedSensorAsync: - def test_init(self): - task = S3KeysUnchangedSensorAsync( + @mock.patch(f"{MODULE}.S3KeysUnchangedSensorAsync.defer") + @mock.patch(f"{MODULE}.S3KeysUnchangedSensorAsync.poke", return_value=True) + def test_s3_keys_unchanged_sensor_async_finish_before_deferred(self, mock_poke, mock_defer, context): + """Assert task is not deferred when it receives a finish status before deferring""" + S3KeysUnchangedSensorAsync( + task_id="s3_keys_unchanged_sensor", bucket_name="test_bucket", prefix="test" + ) + assert not mock_defer.called + + @mock.patch(f"{MODULE}.S3KeysUnchangedSensorAsync.poke", return_value=False) + @mock.patch("airflow.providers.amazon.aws.sensors.s3.S3Hook") + def test_s3_keys_unchanged_sensor_check_trigger_instance(self, mock_hook, mock_poke, context): + """ + Asserts that a task is deferred and an S3KeysUnchangedTrigger will be fired + when the S3KeysUnchangedSensorAsync is executed. + """ + mock_hook.check_for_key.return_value = False + + sensor = S3KeysUnchangedSensorAsync( task_id="s3_keys_unchanged_sensor", bucket_name="test_bucket", prefix="test" ) - assert isinstance(task, S3KeysUnchangedSensor) - assert task.deferrable is True + + with pytest.raises(TaskDeferred) as exc: + sensor.execute(context) + + assert isinstance( + exc.value.trigger, S3KeysUnchangedTrigger + ), "Trigger is not a S3KeysUnchangedTrigger" + + @parameterized.expand([["bucket", "test"]]) + @mock.patch(f"{MODULE}.S3KeysUnchangedSensorAsync.poke", return_value=False) + @mock.patch("airflow.providers.amazon.aws.sensors.s3.S3Hook") + def test_s3_keys_unchanged_sensor_execute_complete_success(self, bucket, prefix, mock_hook, mock_poke): + """ + Asserts that a task completed with success status + """ + mock_hook.check_for_key.return_value = False + + sensor = S3KeysUnchangedSensorAsync( + task_id="s3_keys_unchanged_sensor", + bucket_name=bucket, + prefix=prefix, + ) + assert sensor.execute_complete(context={}, event={"status": "success"}) is None + + @parameterized.expand([["bucket", "test"]]) + @mock.patch(f"{MODULE}.S3KeysUnchangedSensorAsync.poke", return_value=False) + @mock.patch("airflow.providers.amazon.aws.sensors.s3.S3Hook") + def test_s3_keys_unchanged_sensor_execute_complete_error(self, bucket, prefix, mock_hook, mock_poke): + """ + Asserts that a task is completed with error. + """ + mock_hook.check_for_key.return_value = False + + sensor = S3KeysUnchangedSensorAsync( + task_id="s3_keys_unchanged_sensor", + bucket_name=bucket, + prefix=prefix, + ) + with pytest.raises(AirflowException): + sensor.execute_complete(context={}, event={"status": "error", "message": "Mocked error"}) + + @mock.patch(f"{MODULE}.S3KeysUnchangedSensorAsync.poke", return_value=False) + def test_s3_keys_unchanged_sensor_raise_value_error(self, mock_poke): + """ + Test if the S3KeysUnchangedTrigger raises Value error for negative inactivity_period. + """ + with pytest.raises(ValueError): + S3KeysUnchangedSensorAsync( + task_id="s3_keys_unchanged_sensor", + bucket_name="test_bucket", + prefix="test", + inactivity_period=-100, + ) class TestS3KeySizeSensorAsync(unittest.TestCase): diff --git a/tests/amazon/aws/triggers/test_s3_triggers.py b/tests/amazon/aws/triggers/test_s3_triggers.py index a1ca8c2a7..9cd043e02 100644 --- a/tests/amazon/aws/triggers/test_s3_triggers.py +++ b/tests/amazon/aws/triggers/test_s3_triggers.py @@ -1,10 +1,12 @@ import asyncio +from datetime import datetime from unittest import mock import pytest from airflow.triggers.base import TriggerEvent from astronomer.providers.amazon.aws.triggers.s3 import ( + S3KeysUnchangedTrigger, S3KeyTrigger, ) @@ -99,3 +101,79 @@ async def test_run_check_fn_success(self, mock_get_files, mock_client): generator = trigger.run() actual = await generator.asend(None) assert TriggerEvent({"status": "running", "files": [{"Size": 123}]}) == actual + + +class TestS3KeysUnchangedTrigger: + def test_serialization(self): + """ + Asserts that the TaskStateTrigger correctly serializes its arguments + and classpath. + """ + trigger = S3KeysUnchangedTrigger( + bucket_name="test_bucket", + prefix="test", + inactivity_period=1, + min_objects=1, + inactivity_seconds=0, + previous_objects=None, + ) + classpath, kwargs = trigger.serialize() + assert classpath == "astronomer.providers.amazon.aws.triggers.s3.S3KeysUnchangedTrigger" + assert kwargs == { + "bucket_name": "test_bucket", + "prefix": "test", + "inactivity_period": 1, + "min_objects": 1, + "inactivity_seconds": 0, + "previous_objects": set(), + "allow_delete": 1, + "aws_conn_id": "aws_default", + "last_activity_time": None, + } + + @pytest.mark.asyncio + @mock.patch("astronomer.providers.amazon.aws.triggers.s3.S3HookAsync.get_client_async") + async def test_run_wait(self, mock_client): + """Test if the task is run is in trigger successfully.""" + mock_client.return_value.check_key.return_value = True + trigger = S3KeysUnchangedTrigger(bucket_name="test_bucket", prefix="test") + with mock_client: + task = asyncio.create_task(trigger.run().__anext__()) + await asyncio.sleep(0.5) + + assert task.done() is True + asyncio.get_event_loop().stop() + + def test_run_raise_value_error(self): + """ + Test if the S3KeysUnchangedTrigger raises Value error for negative inactivity_period. + """ + with pytest.raises(ValueError): + S3KeysUnchangedTrigger(bucket_name="test_bucket", prefix="test", inactivity_period=-100) + + @pytest.mark.asyncio + @mock.patch("astronomer.providers.amazon.aws.triggers.s3.S3HookAsync.get_client_async") + @mock.patch("astronomer.providers.amazon.aws.triggers.s3.S3HookAsync.is_keys_unchanged") + async def test_run_success(self, mock_is_keys_unchanged, mock_client): + """ + Test if the task is run is in triggerer successfully. + """ + mock_is_keys_unchanged.return_value = {"status": "success"} + trigger = S3KeysUnchangedTrigger(bucket_name="test_bucket", prefix="test") + generator = trigger.run() + actual = await generator.asend(None) + assert TriggerEvent({"status": "success"}) == actual + + @pytest.mark.asyncio + @mock.patch("astronomer.providers.amazon.aws.triggers.s3.S3HookAsync.get_client_async") + @mock.patch("astronomer.providers.amazon.aws.triggers.s3.S3HookAsync.is_keys_unchanged") + async def test_run_pending(self, mock_is_keys_unchanged, mock_client): + """Test if the task is run is in triggerer successfully.""" + mock_is_keys_unchanged.return_value = {"status": "pending", "last_activity_time": datetime.now()} + trigger = S3KeysUnchangedTrigger(bucket_name="test_bucket", prefix="test") + task = asyncio.create_task(trigger.run().__anext__()) + await asyncio.sleep(0.5) + + # TriggerEvent was not returned + assert task.done() is False + asyncio.get_event_loop().stop()