Skip to content

Commit

Permalink
Merge branch 'main' into deprecate-EmrContainerOperatorAsync
Browse files Browse the repository at this point in the history
  • Loading branch information
Lee-W authored Dec 22, 2023
2 parents 6671a06 + e3f50bd commit 5f1789f
Show file tree
Hide file tree
Showing 6 changed files with 321 additions and 22 deletions.
2 changes: 1 addition & 1 deletion astronomer/providers/amazon/aws/operators/redshift_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def __init__(
) -> None:
self.redshift_conn_id = redshift_conn_id
self.poll_interval = poll_interval
if self.__class__.__base__.__name__ == "RedshiftSQLOperator":
if self.__class__.__base__.__name__ == "RedshiftSQLOperator": # type: ignore[union-attr]
# It's better to do str check of the parent class name because currently RedshiftSQLOperator
# is deprecated and in future OSS RedshiftSQLOperator may be removed
super().__init__(**kwargs)
Expand Down
80 changes: 66 additions & 14 deletions astronomer/providers/amazon/aws/sensors/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@
from airflow.sensors.base import BaseSensorOperator

from astronomer.providers.amazon.aws.triggers.s3 import (
S3KeysUnchangedTrigger,
S3KeyTrigger,
)
from astronomer.providers.utils.sensor_util import raise_error_or_skip_exception
from astronomer.providers.utils.sensor_util import poke, raise_error_or_skip_exception
from astronomer.providers.utils.typing_compat import Context


Expand Down Expand Up @@ -154,21 +155,72 @@ def __init__(

class S3KeysUnchangedSensorAsync(S3KeysUnchangedSensor):
"""
This class is deprecated.
Please use :class: `~airflow.providers.amazon.aws.sensors.s3.S3KeysUnchangedSensor`.
Checks for changes in the number of objects at prefix in AWS S3
bucket and returns True if the inactivity period has passed with no
increase in the number of objects. Note, this sensor will not behave correctly
in reschedule mode, as the state of the listed objects in the S3 bucket will
be lost between rescheduled invocations.
:param bucket_name: Name of the S3 bucket
:param prefix: The prefix being waited on. Relative path from bucket root level.
:param aws_conn_id: a reference to the s3 connection
:param verify: Whether or not to verify SSL certificates for S3 connection.
By default SSL certificates are verified.
You can provide the following values:
- ``False``: do not validate SSL certificates. SSL will still be used
(unless use_ssl is False), but SSL certificates will not be
verified.
- ``path/to/cert/bundle.pem``: A filename of the CA cert bundle to uses.
You can specify this argument if you want to use a different
CA cert bundle than the one used by botocore.
:param inactivity_period: The total seconds of inactivity to designate
keys unchanged. Note, this mechanism is not real time and
this operator may not return until a poke_interval after this period
has passed with no additional objects sensed.
:param min_objects: The minimum number of objects needed for keys unchanged
sensor to be considered valid.
:param previous_objects: The set of object ids found during the last poke.
:param allow_delete: Should this sensor consider objects being deleted
between pokes valid behavior. If true a warning message will be logged
when this happens. If false an error will be raised.
"""

def __init__(self, *args, **kwargs): # type: ignore[no-untyped-def]
warnings.warn(
(
"This module is deprecated. "
"Please use `airflow.providers.amazon.aws.sensors.s3.S3KeysUnchangedSensor` "
"and set deferrable to True instead."
),
DeprecationWarning,
stacklevel=2,
)
return super().__init__(*args, deferrable=True, **kwargs)
def __init__(
self,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)

def execute(self, context: Context) -> None:
    """
    Poke once synchronously; if the condition is not yet met, defer to ``S3KeysUnchangedTrigger``.

    When the initial poke returns a truthy value the method returns without deferring.
    Otherwise the current sensor state is handed to the trigger and the task resumes in
    ``execute_complete`` once the trigger fires.

    :param context: the Airflow task context, forwarded to the ``poke`` helper.
    """
    # Nothing to wait for: the initial poke already succeeded.
    if poke(self, context):
        return

    defer_timeout = timedelta(seconds=self.timeout)
    # Hand over every piece of sensor state the trigger needs to continue the check.
    keys_unchanged_trigger = S3KeysUnchangedTrigger(
        bucket_name=self.bucket_name,
        prefix=self.prefix,
        inactivity_period=self.inactivity_period,
        min_objects=self.min_objects,
        previous_objects=self.previous_objects,
        inactivity_seconds=self.inactivity_seconds,
        allow_delete=self.allow_delete,
        aws_conn_id=self.aws_conn_id,
        verify=self.verify,
        last_activity_time=self.last_activity_time,
    )
    self.defer(
        timeout=defer_timeout,
        trigger=keys_unchanged_trigger,
        method_name="execute_complete",
    )

def execute_complete(self, context: Context, event: Any = None) -> None:
    """
    Callback invoked when the trigger fires; returns immediately unless the event is an error.

    :param context: the Airflow task context (unused here).
    :param event: payload of the fired trigger event. An event whose ``"status"`` is
        ``"error"`` is passed, together with ``soft_fail``, to
        ``raise_error_or_skip_exception``. Any other event, including a missing one,
        is treated as success.
    """
    # ``event`` defaults to None; guard before subscripting so a missing payload
    # is treated as success instead of failing with a TypeError.
    if event and event["status"] == "error":
        raise_error_or_skip_exception(self.soft_fail, event["message"])
    return None


class S3PrefixSensorAsync(BaseSensorOperator):
Expand Down
101 changes: 101 additions & 0 deletions astronomer/providers/amazon/aws/triggers/s3.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import asyncio
from datetime import datetime
from typing import Any, AsyncIterator

from airflow.triggers.base import BaseTrigger, TriggerEvent
Expand Down Expand Up @@ -90,3 +91,103 @@ async def run(self) -> AsyncIterator[TriggerEvent]:

def _get_async_hook(self) -> S3HookAsync:
return S3HookAsync(aws_conn_id=self.aws_conn_id, verify=self.hook_params.get("verify"))


class S3KeysUnchangedTrigger(BaseTrigger):
    """
    S3KeysUnchangedTrigger is fired as deferred class with params to run the task in trigger worker

    :param bucket_name: Name of the S3 bucket.
    :param prefix: The prefix being waited on. Relative path from bucket root level.
    :param inactivity_period: The total seconds of inactivity to designate
        keys unchanged. Note, this mechanism is not real time and
        this operator may not return until a poke_interval after this period
        has passed with no additional objects sensed.
    :param min_objects: The minimum number of objects needed for keys unchanged
        sensor to be considered valid.
    :param inactivity_seconds: reference to the seconds of inactivity
    :param previous_objects: The set of object ids found during the last poke.
    :param allow_delete: Should this sensor consider objects being deleted
    :param aws_conn_id: reference to the s3 connection
    :param last_activity_time: last modified or last active time
    :param verify: Whether or not to verify SSL certificates for S3 connection.
        By default SSL certificates are verified.
    :raises ValueError: if ``inactivity_period`` is negative.
    """

    def __init__(
        self,
        bucket_name: str,
        prefix: str,
        inactivity_period: float = 60 * 60,
        min_objects: int = 1,
        inactivity_seconds: int = 0,
        previous_objects: set[str] | None = None,
        allow_delete: bool = True,
        aws_conn_id: str = "aws_default",
        last_activity_time: datetime | None = None,
        verify: bool | str | None = None,
    ):
        super().__init__()
        self.bucket_name = bucket_name
        self.prefix = prefix
        if inactivity_period < 0:
            raise ValueError("inactivity_period must be non-negative")
        if previous_objects is None:
            # Fresh set per instance; never share a mutable default.
            previous_objects = set()
        self.inactivity_period = inactivity_period
        self.min_objects = min_objects
        self.previous_objects = previous_objects
        self.inactivity_seconds = inactivity_seconds
        self.allow_delete = allow_delete
        self.aws_conn_id = aws_conn_id
        self.last_activity_time: datetime | None = last_activity_time
        self.verify = verify
        # Delay between consecutive checks in ``run``; 0 only yields control to the event loop.
        self.polling_period_seconds = 0

    def serialize(self) -> tuple[str, dict[str, Any]]:
        """Serialize S3KeysUnchangedTrigger arguments and classpath."""
        return (
            "astronomer.providers.amazon.aws.triggers.s3.S3KeysUnchangedTrigger",
            {
                "bucket_name": self.bucket_name,
                "prefix": self.prefix,
                "inactivity_period": self.inactivity_period,
                "min_objects": self.min_objects,
                "previous_objects": self.previous_objects,
                "inactivity_seconds": self.inactivity_seconds,
                "allow_delete": self.allow_delete,
                "aws_conn_id": self.aws_conn_id,
                "last_activity_time": self.last_activity_time,
                # Must round-trip: the trigger is re-created from these kwargs, and
                # dropping ``verify`` would silently reset the SSL setting to None.
                "verify": self.verify,
            },
        )

    async def run(self) -> AsyncIterator[TriggerEvent]:
        """
        Make an asynchronous connection using S3HookAsync and poll until a terminal status.

        Repeatedly calls ``hook.is_keys_unchanged``. A ``"success"`` or ``"error"`` result
        is emitted as a single ``TriggerEvent`` and ends the run; a ``"pending"`` result
        updates the tracked state for the next iteration. Any exception is reported as an
        error event.
        """
        try:
            hook = self._get_async_hook()
            async with await hook.get_client_async() as client:
                while True:
                    result = await hook.is_keys_unchanged(
                        client,
                        self.bucket_name,
                        self.prefix,
                        self.inactivity_period,
                        self.min_objects,
                        self.previous_objects,
                        self.inactivity_seconds,
                        self.allow_delete,
                        self.last_activity_time,
                    )
                    status = result.get("status")
                    if status in ("success", "error"):
                        yield TriggerEvent(result)
                        # Terminal state: stop here so the trigger fires exactly one
                        # event instead of polling again and emitting duplicates.
                        return
                    if status == "pending":
                        # Carry the hook's bookkeeping over to the next poll.
                        self.previous_objects = result.get("previous_objects", set())
                        self.last_activity_time = result.get("last_activity_time")
                        self.inactivity_seconds = result.get("inactivity_seconds", 0)
                    await asyncio.sleep(self.polling_period_seconds)
        except Exception as e:
            yield TriggerEvent({"status": "error", "message": str(e)})

    def _get_async_hook(self) -> S3HookAsync:
        """Build the async S3 hook from this trigger's connection id and ``verify`` setting."""
        return S3HookAsync(aws_conn_id=self.aws_conn_id, verify=self.verify)
4 changes: 2 additions & 2 deletions astronomer/providers/snowflake/operators/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def __init__(
self.authenticator = authenticator
self.session_parameters = session_parameters
self.snowflake_conn_id = snowflake_conn_id
if self.__class__.__base__.__name__ != "SnowflakeOperator":
if self.__class__.__base__.__name__ != "SnowflakeOperator": # type: ignore[union-attr]
# It's better to do str check of the parent class name because currently SnowflakeOperator
# is deprecated and in future OSS SnowflakeOperator may be removed
if any(
Expand Down Expand Up @@ -319,7 +319,7 @@ def __init__(
self.token_renewal_delta = token_renewal_delta
self.bindings = bindings
self.execute_async = False
if self.__class__.__base__.__name__ != "SnowflakeOperator":
if self.__class__.__base__.__name__ != "SnowflakeOperator": # type: ignore[union-attr]
# It's better to do str check of the parent class name because currently SnowflakeOperator
# is deprecated and in future OSS SnowflakeOperator may be removed
if any(
Expand Down
78 changes: 73 additions & 5 deletions tests/amazon/aws/sensors/test_s3_sensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from airflow.exceptions import AirflowException, AirflowSkipException, TaskDeferred
from airflow.models import DAG, DagRun, TaskInstance
from airflow.models.variable import Variable
from airflow.providers.amazon.aws.sensors.s3 import S3KeysUnchangedSensor
from airflow.utils import timezone
from parameterized import parameterized

Expand All @@ -18,6 +17,7 @@
S3PrefixSensorAsync,
)
from astronomer.providers.amazon.aws.triggers.s3 import (
S3KeysUnchangedTrigger,
S3KeyTrigger,
)

Expand Down Expand Up @@ -293,12 +293,80 @@ def test_soft_fail_enable(self, context):


class TestS3KeysUnchangedSensorAsync:
def test_init(self):
task = S3KeysUnchangedSensorAsync(
@mock.patch(f"{MODULE}.S3KeysUnchangedSensorAsync.defer")
@mock.patch(f"{MODULE}.S3KeysUnchangedSensorAsync.poke", return_value=True)
def test_s3_keys_unchanged_sensor_async_finish_before_deferred(self, mock_poke, mock_defer, context):
    """Assert task is not deferred when it receives a finish status before deferring"""
    sensor = S3KeysUnchangedSensorAsync(
        task_id="s3_keys_unchanged_sensor", bucket_name="test_bucket", prefix="test"
    )
    # The sensor must actually be executed; merely constructing it can never call
    # ``defer``, which would make the assertion below pass vacuously.
    sensor.execute(context)
    assert not mock_defer.called

@mock.patch(f"{MODULE}.S3KeysUnchangedSensorAsync.poke", return_value=False)
@mock.patch("airflow.providers.amazon.aws.sensors.s3.S3Hook")
def test_s3_keys_unchanged_sensor_check_trigger_instance(self, mock_hook, mock_poke, context):
"""
Asserts that a task is deferred and an S3KeysUnchangedTrigger will be fired
when the S3KeysUnchangedSensorAsync is executed.
"""
mock_hook.check_for_key.return_value = False

sensor = S3KeysUnchangedSensorAsync(
task_id="s3_keys_unchanged_sensor", bucket_name="test_bucket", prefix="test"
)
assert isinstance(task, S3KeysUnchangedSensor)
assert task.deferrable is True

with pytest.raises(TaskDeferred) as exc:
sensor.execute(context)

assert isinstance(
exc.value.trigger, S3KeysUnchangedTrigger
), "Trigger is not a S3KeysUnchangedTrigger"

@parameterized.expand([["bucket", "test"]])
@mock.patch(f"{MODULE}.S3KeysUnchangedSensorAsync.poke", return_value=False)
@mock.patch("airflow.providers.amazon.aws.sensors.s3.S3Hook")
def test_s3_keys_unchanged_sensor_execute_complete_success(self, bucket, prefix, mock_hook, mock_poke):
    """
    ``execute_complete`` returns None when the trigger event reports success.
    """
    mock_hook.check_for_key.return_value = False

    success_event = {"status": "success"}
    sensor = S3KeysUnchangedSensorAsync(
        task_id="s3_keys_unchanged_sensor", bucket_name=bucket, prefix=prefix
    )

    outcome = sensor.execute_complete(context={}, event=success_event)

    assert outcome is None

@parameterized.expand([["bucket", "test"]])
@mock.patch(f"{MODULE}.S3KeysUnchangedSensorAsync.poke", return_value=False)
@mock.patch("airflow.providers.amazon.aws.sensors.s3.S3Hook")
def test_s3_keys_unchanged_sensor_execute_complete_error(self, bucket, prefix, mock_hook, mock_poke):
    """
    ``execute_complete`` raises AirflowException when the trigger event reports an error.
    """
    mock_hook.check_for_key.return_value = False

    error_event = {"status": "error", "message": "Mocked error"}
    sensor = S3KeysUnchangedSensorAsync(
        task_id="s3_keys_unchanged_sensor", bucket_name=bucket, prefix=prefix
    )

    with pytest.raises(AirflowException):
        sensor.execute_complete(context={}, event=error_event)

@mock.patch(f"{MODULE}.S3KeysUnchangedSensorAsync.poke", return_value=False)
def test_s3_keys_unchanged_sensor_raise_value_error(self, mock_poke):
    """
    Constructing S3KeysUnchangedSensorAsync with a negative inactivity_period raises ValueError.
    """
    invalid_kwargs = {
        "task_id": "s3_keys_unchanged_sensor",
        "bucket_name": "test_bucket",
        "prefix": "test",
        "inactivity_period": -100,
    }
    with pytest.raises(ValueError):
        S3KeysUnchangedSensorAsync(**invalid_kwargs)


class TestS3KeySizeSensorAsync(unittest.TestCase):
Expand Down
Loading

0 comments on commit 5f1789f

Please sign in to comment.