Skip to content

Commit

Permalink
feat(providers/amazon): deprecate RedshiftDataOperatorAsync
Browse files Browse the repository at this point in the history
  • Loading branch information
Lee-W committed Jan 25, 2024
1 parent bb61993 commit 19576cd
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 156 deletions.
13 changes: 13 additions & 0 deletions astronomer/providers/amazon/aws/hooks/redshift_data.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import asyncio
import warnings
from typing import Any, Iterable

import botocore.exceptions
Expand All @@ -18,6 +19,9 @@ class RedshiftDataHook(AwsBaseHook):
RedshiftDataHook inherits from AwsBaseHook to connect with AWS redshift
by using boto3 client_type as redshift-data we can interact with redshift cluster database and execute the query
This class is deprecated and will be removed in 2.0.0.
Use :class:`~airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook` instead
:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is None or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
Expand All @@ -34,6 +38,15 @@ class RedshiftDataHook(AwsBaseHook):
"""

def __init__(self, *args: Any, poll_interval: int = 0, **kwargs: Any) -> None:
warnings.warn(
(
"This module is deprecated and will be removed in 2.0.0."
"Please use `airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook`"
),
DeprecationWarning,
stacklevel=2,
)

aws_connection_type: str = "redshift-data"
try:
# for apache-airflow-providers-amazon>=3.0.0
Expand Down
71 changes: 13 additions & 58 deletions astronomer/providers/amazon/aws/operators/redshift_data.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,14 @@
import warnings
from typing import Any

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.operators.redshift_data import RedshiftDataOperator

from astronomer.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook
from astronomer.providers.amazon.aws.triggers.redshift_data import RedshiftDataTrigger
from astronomer.providers.utils.typing_compat import Context


class RedshiftDataOperatorAsync(RedshiftDataOperator):
"""
Executes SQL Statements against an Amazon Redshift cluster.
If there are multiple queries as part of the SQL, and one of them fails to reach a successful completion state,
the operator returns the relevant error for the failed query.
:param sql: the SQL code to be executed as a single string, or
a list of str (sql statements), or a reference to a template file.
Template references are recognized by str ending in '.sql'
:param aws_conn_id: AWS connection ID
:param parameters: (optional) the parameters to render the SQL query with.
:param autocommit: if True, each command is automatically committed.
(default value: False)
This class is deprecated.
Please use :class:`~airflow.providers.amazon.aws.operators.redshift_data.RedshiftDataOperator`
and set `deferrable` param to `True` instead.
"""

def __init__(
Expand All @@ -29,47 +17,14 @@ def __init__(
poll_interval: int = 5,
**kwargs: Any,
) -> None:
self.poll_interval = poll_interval
super().__init__(**kwargs)

def execute(self, context: Context) -> None:
"""
Makes a sync call to RedshiftDataHook, executes the query and gets back the list of query_ids and
defers trigger to poll for the status for the queries executed.
"""
redshift_data_hook = RedshiftDataHook(aws_conn_id=self.aws_conn_id)
query_ids, response = redshift_data_hook.execute_query(sql=self.sql, params=self.params)
self.log.info("Query IDs %s", query_ids)
if response.get("status") == "error":
self.execute_complete(context, event=response)
context["ti"].xcom_push(key="return_value", value=query_ids)

if redshift_data_hook.queries_are_completed(query_ids, context):
self.log.info("%s completed successfully.", self.task_id)
return

self.defer(
timeout=self.execution_timeout,
trigger=RedshiftDataTrigger(
task_id=self.task_id,
poll_interval=self.poll_interval,
aws_conn_id=self.aws_conn_id,
query_ids=query_ids,
warnings.warn(
(
"This module is deprecated and will be removed in 2.0.0."
"Please use `airflow.providers.amazon.aws.operators.redshift_data.RedshiftDataOperator`"
"and set `deferrable` param to `True` instead."
),
method_name="execute_complete",
DeprecationWarning,
stacklevel=2,
)

def execute_complete(self, context: Context, event: Any = None) -> None:
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
successful.
"""
if event:
if "status" in event and event["status"] == "error":
msg = "context: {}, error message: {}".format(context, event["message"])
raise AirflowException(msg)
elif "status" in event and event["status"] == "success":
self.log.info("%s completed successfully.", self.task_id)
else:
raise AirflowException("Did not receive valid event from the trigerrer")
kwargs["poll_interval"] = poll_interval
super().__init__(deferrable=True, **kwargs)
23 changes: 19 additions & 4 deletions astronomer/providers/amazon/aws/triggers/redshift_data.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from typing import Any, AsyncIterator, Dict, List, Tuple
from __future__ import annotations

import warnings
from typing import Any, AsyncIterator

from airflow.triggers.base import BaseTrigger, TriggerEvent

Expand All @@ -9,6 +12,9 @@ class RedshiftDataTrigger(BaseTrigger):
"""
RedshiftDataTrigger is fired as deferred class with params to run the task in triggerer.
This class is deprecated and will be removed in 2.0.0.
Use :class:`~airflow.providers.amazon.aws.triggers.redshift_data.RedshiftDataTrigger` instead
:param task_id: task ID of the Dag
:param poll_interval: polling period in seconds to check for the status
:param aws_conn_id: AWS connection ID for redshift
def __init__(
    self,
    task_id: str,
    poll_interval: int,
    query_ids: list[str],
    aws_conn_id: str = "aws_default",
):
    """Initialize the (deprecated) trigger and emit a DeprecationWarning.

    :param task_id: task ID of the dag
    :param poll_interval: polling period in seconds to check for the status
    :param query_ids: IDs of the Redshift Data API statements to poll
    :param aws_conn_id: AWS connection ID for redshift
    """
    # Trailing space after "2.0.0." is required: the two adjacent string
    # literals are implicitly concatenated, and without it the message
    # rendered as "...removed in 2.0.0.Please use...".
    warnings.warn(
        (
            "This module is deprecated and will be removed in 2.0.0. "
            "Please use `airflow.providers.amazon.aws.triggers.redshift_data.RedshiftDataTrigger`"
        ),
        DeprecationWarning,
        stacklevel=2,  # point the warning at the caller, not this __init__
    )

    super().__init__()
    self.task_id = task_id
    self.poll_interval = poll_interval
    self.aws_conn_id = aws_conn_id
    self.query_ids = query_ids

def serialize(self) -> Tuple[str, Dict[str, Any]]:
def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serializes RedshiftDataTrigger arguments and classpath."""
return (
"astronomer.providers.amazon.aws.triggers.redshift_data.RedshiftDataTrigger",
Expand All @@ -40,7 +55,7 @@ def serialize(self) -> Tuple[str, Dict[str, Any]]:
},
)

async def run(self) -> AsyncIterator["TriggerEvent"]:
async def run(self) -> AsyncIterator[TriggerEvent]:
"""
Makes async connection and gets status for a list of queries submitted by the operator.
Even if one of the queries has a non-successful state, the hook returns a failure event and the error
Expand Down
103 changes: 9 additions & 94 deletions tests/amazon/aws/operators/test_redshift_data.py
Original file line number Diff line number Diff line change
@@ -1,101 +1,16 @@
from unittest import mock

import pytest
from airflow.exceptions import AirflowException, TaskDeferred
from airflow.providers.amazon.aws.operators.redshift_data import RedshiftDataOperator

from astronomer.providers.amazon.aws.operators.redshift_data import (
RedshiftDataOperatorAsync,
)
from astronomer.providers.amazon.aws.triggers.redshift_data import RedshiftDataTrigger
from tests.utils.airflow_util import create_context


class TestRedshiftDataOperatorAsync:
DATABASE_NAME = "TEST_DATABASE"
TASK_ID = "fetch_data"
SQL_QUERY = "select * from any"
TASK = RedshiftDataOperatorAsync(
task_id=TASK_ID,
sql=SQL_QUERY,
database=DATABASE_NAME,
)

@mock.patch("astronomer.providers.amazon.aws.operators.redshift_data.RedshiftDataOperatorAsync.defer")
@mock.patch("astronomer.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
@mock.patch("astronomer.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query")
def test_redshift_data_op_async_finished_before_deferred(self, mock_execute, mock_conn, mock_defer):
mock_execute.return_value = ["test_query_id"], {}
mock_conn.describe_statement.return_value = {
"Status": "FINISHED",
}
self.TASK.execute(create_context(self.TASK))
assert not mock_defer.called

@mock.patch("astronomer.providers.amazon.aws.operators.redshift_data.RedshiftDataOperatorAsync.defer")
@mock.patch("astronomer.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
@mock.patch("astronomer.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query")
def test_redshift_data_op_async_aborted_before_deferred(self, mock_execute, mock_conn, mock_defer):
mock_execute.return_value = ["test_query_id"], {}
mock_conn.describe_statement.return_value = {"Status": "ABORTED"}

with pytest.raises(AirflowException):
self.TASK.execute(create_context(self.TASK))

assert not mock_defer.called

@mock.patch("astronomer.providers.amazon.aws.operators.redshift_data.RedshiftDataOperatorAsync.defer")
@mock.patch("astronomer.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
@mock.patch("astronomer.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query")
def test_redshift_data_op_async_failed_before_deferred(self, mock_execute, mock_conn, mock_defer):
mock_execute.return_value = ["test_query_id"], {}
mock_conn.describe_statement.return_value = {
"Status": "FAILED",
"QueryString": "test query",
"Error": "test error",
}

with pytest.raises(AirflowException):
self.TASK.execute(create_context(self.TASK))

assert not mock_defer.called

@pytest.mark.parametrize("status", ("SUBMITTED", "PICKED", "STARTED"))
@mock.patch("astronomer.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
@mock.patch("astronomer.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query")
def test_redshift_data_op_async(self, mock_execute, mock_conn, status):
mock_execute.return_value = ["test_query_id"], {}
mock_conn.describe_statement.return_value = {"Status": status}

with pytest.raises(TaskDeferred) as exc:
self.TASK.execute(create_context(self.TASK))
assert isinstance(exc.value.trigger, RedshiftDataTrigger), "Trigger is not a RedshiftDataTrigger"

@mock.patch("astronomer.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query")
def test_redshift_data_op_async_execute_query_error(self, mock_execute, context):
mock_execute.return_value = [], {"status": "error", "message": "Test exception"}
with pytest.raises(AirflowException):
self.TASK.execute(context)

def test_redshift_data_op_async_execute_failure(self, context):
"""Tests that an AirflowException is raised in case of error event"""

with pytest.raises(AirflowException):
self.TASK.execute_complete(
context=None, event={"status": "error", "message": "test failure message"}
)

@pytest.mark.parametrize(
"event",
[None, {"status": "success", "message": "Job completed"}],
)
def test_redshift_data_op_async_execute_complete(self, event):
"""Asserts that logging occurs as expected"""

if not event:
with pytest.raises(AirflowException) as exception_info:
self.TASK.execute_complete(context=None, event=None)
assert exception_info.value.args[0] == "Did not receive valid event from the trigerrer"
else:
with mock.patch.object(self.TASK.log, "info") as mock_log_info:
self.TASK.execute_complete(context=None, event=event)
mock_log_info.assert_called_with("%s completed successfully.", self.TASK_ID)
def test_init(self):
    """The async operator is just a deferrable RedshiftDataOperator now."""
    operator = RedshiftDataOperatorAsync(
        task_id="fetch_data",
        sql="select * from any",
        database="TEST_DATABASE",
    )
    assert isinstance(operator, RedshiftDataOperator)
    assert operator.deferrable is True

0 comments on commit 19576cd

Please sign in to comment.