Skip to content

Commit

Permalink
feat(providers/amazon): deprecate RedshiftClusterSensorAsync
Browse files Browse the repository at this point in the history
  • Loading branch information
Lee-W committed Jan 24, 2024
1 parent 6c4a395 commit 2b47fd0
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 116 deletions.
63 changes: 15 additions & 48 deletions astronomer/providers/amazon/aws/sensors/redshift_cluster.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,14 @@
import warnings
from datetime import timedelta
from typing import Any, Dict, Optional
from typing import Any

from airflow.providers.amazon.aws.sensors.redshift_cluster import RedshiftClusterSensor

from astronomer.providers.amazon.aws.triggers.redshift_cluster import (
RedshiftClusterSensorTrigger,
)
from astronomer.providers.utils.sensor_util import poke, raise_error_or_skip_exception
from astronomer.providers.utils.typing_compat import Context


class RedshiftClusterSensorAsync(RedshiftClusterSensor):
"""
Waits for a Redshift cluster to reach a specific status.
:param cluster_identifier: The identifier for the cluster being pinged.
:param target_status: The cluster status desired.
This class is deprecated.
Please use :class:`~airflow.providers.amazon.aws.sensors.redshift_cluster.RedshiftClusterSensor`
and set the `deferrable` param to `True` instead.
"""

def __init__(
Expand All @@ -27,45 +19,20 @@ def __init__(
):
# TODO: Remove once deprecated
if poll_interval:
self.poke_interval = poll_interval
kwargs["poke_interval"] = poll_interval
warnings.warn(
"Argument `poll_interval` is deprecated and will be removed "
"in a future release. Please use `poke_interval` instead.",
DeprecationWarning,
stacklevel=2,
)
super().__init__(**kwargs)

def execute(self, context: Context) -> None:
"""Check for the target_status and defers using the trigger"""
if not poke(self, context):
self.defer(
timeout=timedelta(seconds=self.timeout),
trigger=RedshiftClusterSensorTrigger(
task_id=self.task_id,
aws_conn_id=self.aws_conn_id,
cluster_identifier=self.cluster_identifier,
target_status=self.target_status,
poke_interval=self.poke_interval,
),
method_name="execute_complete",
)

def execute_complete(self, context: Context, event: Optional[Dict[Any, Any]] = None) -> None:
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
successful.
"""
if event:
if "status" in event and event["status"] == "error":
msg = "{}: {}".format(event["status"], event["message"])
raise_error_or_skip_exception(self.soft_fail, msg)
if "status" in event and event["status"] == "success":
self.log.info("%s completed successfully.", self.task_id)
self.log.info(
"Cluster Identifier %s is in %s state", self.cluster_identifier, self.target_status
)
return None
self.log.info("%s completed successfully.", self.task_id)
return None
warnings.warn(
(
"This module is deprecated and will be removed in 2.0.0. "
"Please use `airflow.providers.amazon.aws.sensors.redshift_cluster.RedshiftClusterSensor` "
"and set `deferrable` param to `True` instead."
),
DeprecationWarning,
stacklevel=2,
)
super().__init__(deferrable=True, **kwargs)
25 changes: 19 additions & 6 deletions astronomer/providers/amazon/aws/triggers/redshift_cluster.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

import asyncio
import warnings
from typing import Any, AsyncIterator, Dict, Optional, Tuple
from typing import Any, AsyncIterator

from airflow.triggers.base import BaseTrigger, TriggerEvent

Expand Down Expand Up @@ -29,7 +31,7 @@ def __init__(
operation_type: str,
polling_period_seconds: float = 5.0,
skip_final_cluster_snapshot: bool = True,
final_cluster_snapshot_identifier: Optional[str] = None,
final_cluster_snapshot_identifier: str | None = None,
):
warnings.warn(
(
Expand All @@ -48,7 +50,7 @@ def __init__(
self.skip_final_cluster_snapshot = skip_final_cluster_snapshot
self.final_cluster_snapshot_identifier = final_cluster_snapshot_identifier

def serialize(self) -> Tuple[str, Dict[str, Any]]:
def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serializes RedshiftClusterTrigger arguments and classpath."""
return (
"astronomer.providers.amazon.aws.triggers.redshift_cluster.RedshiftClusterTrigger",
Expand All @@ -63,7 +65,7 @@ def serialize(self) -> Tuple[str, Dict[str, Any]]:
},
)

async def run(self) -> AsyncIterator["TriggerEvent"]:
async def run(self) -> AsyncIterator[TriggerEvent]:
"""
Make async connection to redshift, based on the operation type call
the RedshiftHookAsync functions
Expand Down Expand Up @@ -112,6 +114,9 @@ class RedshiftClusterSensorTrigger(BaseTrigger):
"""
RedshiftClusterSensorTrigger is fired as a deferred class with params to run the task in the trigger worker.
This class is deprecated and will be removed in 2.0.0.
Use :class:`~airflow.providers.amazon.aws.triggers.redshift_cluster.RedshiftClusterTrigger` instead.
:param task_id: Reference to task id of the Dag
:param aws_conn_id: Reference to AWS connection id for redshift
:param cluster_identifier: unique identifier of a cluster
Expand All @@ -127,14 +132,22 @@ def __init__(
target_status: str,
poke_interval: float,
):
warnings.warn(
(
"This module is deprecated and will be removed in 2.0.0. "
"Please use `airflow.providers.amazon.aws.triggers.redshift_cluster.RedshiftClusterTrigger`"
),
DeprecationWarning,
stacklevel=2,
)
super().__init__()
self.task_id = task_id
self.aws_conn_id = aws_conn_id
self.cluster_identifier = cluster_identifier
self.target_status = target_status
self.poke_interval = poke_interval

def serialize(self) -> Tuple[str, Dict[str, Any]]:
def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serializes RedshiftClusterSensorTrigger arguments and classpath."""
return (
"astronomer.providers.amazon.aws.triggers.redshift_cluster.RedshiftClusterSensorTrigger",
Expand All @@ -147,7 +160,7 @@ def serialize(self) -> Tuple[str, Dict[str, Any]]:
},
)

async def run(self) -> AsyncIterator["TriggerEvent"]:
async def run(self) -> AsyncIterator[TriggerEvent]:
"""Simple async function run until the cluster status match the target status."""
try:
hook = RedshiftHookAsync(aws_conn_id=self.aws_conn_id)
Expand Down
72 changes: 10 additions & 62 deletions tests/amazon/aws/sensors/test_redshift_sensor.py
Original file line number Diff line number Diff line change
@@ -1,72 +1,20 @@
from unittest import mock

import pytest
from airflow.exceptions import AirflowException, TaskDeferred
from airflow.providers.amazon.aws.sensors.redshift_cluster import (
RedshiftClusterSensor,
)

from astronomer.providers.amazon.aws.sensors.redshift_cluster import (
RedshiftClusterSensorAsync,
)
from astronomer.providers.amazon.aws.triggers.redshift_cluster import (
RedshiftClusterSensorTrigger,
)

TASK_ID = "redshift_sensor_check"
POLLING_PERIOD_SECONDS = 1.0

MODULE = "astronomer.providers.amazon.aws.sensors.redshift_cluster"


class TestRedshiftClusterSensorAsync:
TASK = RedshiftClusterSensorAsync(
task_id=TASK_ID,
cluster_identifier="astro-redshift-cluster-1",
target_status="available",
)

@mock.patch(f"{MODULE}.RedshiftClusterSensorAsync.defer")
@mock.patch(f"{MODULE}.RedshiftClusterSensorAsync.poke", return_value=True)
def test_redshift_cluster_sensor_async_finish_before_deferred(self, mock_poke, mock_defer, context):
"""Assert task is not deferred when it receives a finish status before deferring"""
self.TASK.execute(context)
assert not mock_defer.called

@mock.patch(f"{MODULE}.RedshiftClusterSensorAsync.poke", return_value=False)
def test_redshift_cluster_sensor_async(self, context):
"""Test RedshiftClusterSensorAsync that a task with wildcard=True
is deferred and an RedshiftClusterSensorTrigger will be fired when executed method is called"""

with pytest.raises(TaskDeferred) as exc:
self.TASK.execute(context)
assert isinstance(
exc.value.trigger, RedshiftClusterSensorTrigger
), "Trigger is not a RedshiftClusterSensorTrigger"

def test_redshift_sensor_async_execute_failure(self, context):
"""Test RedshiftClusterSensorAsync with an AirflowException is raised in case of error event"""

with pytest.raises(AirflowException):
self.TASK.execute_complete(
context=None, event={"status": "error", "message": "test failure message"}
)

def test_redshift_sensor_async_execute_complete(self):
"""Asserts that logging occurs as expected"""

with mock.patch.object(self.TASK.log, "info") as mock_log_info:
self.TASK.execute_complete(
context=None, event={"status": "success", "cluster_state": "available"}
)
mock_log_info.assert_called_with(
"Cluster Identifier %s is in %s state", "astro-redshift-cluster-1", "available"
def test_init(self):
task = RedshiftClusterSensorAsync(
task_id=TASK_ID,
cluster_identifier="astro-redshift-cluster-1",
target_status="available",
)

def test_poll_interval_deprecation_warning(self):
"""Test DeprecationWarning for RedshiftClusterSensorAsync by setting param poll_interval"""
# TODO: Remove once deprecated
with pytest.warns(expected_warning=DeprecationWarning):
RedshiftClusterSensorAsync(
task_id=TASK_ID,
cluster_identifier="astro-redshift-cluster-1",
target_status="available",
poll_interval=5.0,
)
assert isinstance(task, RedshiftClusterSensor)
assert task.deferrable is True

0 comments on commit 2b47fd0

Please sign in to comment.