diff --git a/astronomer/providers/amazon/aws/hooks/redshift_data.py b/astronomer/providers/amazon/aws/hooks/redshift_data.py
index d7b880c1f..0ae56bba1 100644
--- a/astronomer/providers/amazon/aws/hooks/redshift_data.py
+++ b/astronomer/providers/amazon/aws/hooks/redshift_data.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import asyncio
+import warnings
 from typing import Any, Iterable
 
 import botocore.exceptions
@@ -18,6 +19,9 @@ class RedshiftDataHook(AwsBaseHook):
     RedshiftDataHook inherits from AwsBaseHook to connect with AWS redshift
     by using boto3 client_type as redshift-data we can interact with redshift cluster database and execute the query
 
+    This class is deprecated and will be removed in 2.0.0.
+    Use :class:`~airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook` instead.
+
     :param aws_conn_id: The Airflow connection used for AWS credentials.
         If this is None or empty then the default boto3 behaviour is used. If
         running Airflow in a distributed manner and aws_conn_id is None or
@@ -34,6 +38,15 @@ class RedshiftDataHook(AwsBaseHook):
     """
 
     def __init__(self, *args: Any, poll_interval: int = 0, **kwargs: Any) -> None:
+        warnings.warn(
+            (
+                "This module is deprecated and will be removed in 2.0.0. "
+                "Please use `airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook`."
+            ),
+            DeprecationWarning,
+            stacklevel=2,
+        )
+
         aws_connection_type: str = "redshift-data"
         try:
             # for apache-airflow-providers-amazon>=3.0.0
diff --git a/astronomer/providers/amazon/aws/operators/redshift_data.py b/astronomer/providers/amazon/aws/operators/redshift_data.py
index 6dd05d2fa..39aae4ae4 100644
--- a/astronomer/providers/amazon/aws/operators/redshift_data.py
+++ b/astronomer/providers/amazon/aws/operators/redshift_data.py
@@ -1,26 +1,14 @@
+import warnings
 from typing import Any
 
-from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.operators.redshift_data import RedshiftDataOperator
 
-from astronomer.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook
-from astronomer.providers.amazon.aws.triggers.redshift_data import RedshiftDataTrigger
-from astronomer.providers.utils.typing_compat import Context
-
 
 class RedshiftDataOperatorAsync(RedshiftDataOperator):
     """
-    Executes SQL Statements against an Amazon Redshift cluster.
-    If there are multiple queries as part of the SQL, and one of them fails to reach a successful completion state,
-    the operator returns the relevant error for the failed query.
-
-    :param sql: the SQL code to be executed as a single string, or
-        a list of str (sql statements), or a reference to a template file.
-        Template references are recognized by str ending in '.sql'
-    :param aws_conn_id: AWS connection ID
-    :param parameters: (optional) the parameters to render the SQL query with.
-    :param autocommit: if True, each command is automatically committed.
-        (default value: False)
+    This class is deprecated.
+    Please use :class:`~airflow.providers.amazon.aws.operators.redshift_data.RedshiftDataOperator`
+    and set the `deferrable` param to `True` instead.
     """
 
     def __init__(
@@ -29,47 +17,14 @@ def __init__(
         self,
         poll_interval: int = 5,
         **kwargs: Any,
     ) -> None:
-        self.poll_interval = poll_interval
-        super().__init__(**kwargs)
-
-    def execute(self, context: Context) -> None:
-        """
-        Makes a sync call to RedshiftDataHook, executes the query and gets back the list of query_ids and
-        defers trigger to poll for the status for the queries executed.
- """ - redshift_data_hook = RedshiftDataHook(aws_conn_id=self.aws_conn_id) - query_ids, response = redshift_data_hook.execute_query(sql=self.sql, params=self.params) - self.log.info("Query IDs %s", query_ids) - if response.get("status") == "error": - self.execute_complete(context, event=response) - context["ti"].xcom_push(key="return_value", value=query_ids) - - if redshift_data_hook.queries_are_completed(query_ids, context): - self.log.info("%s completed successfully.", self.task_id) - return - - self.defer( - timeout=self.execution_timeout, - trigger=RedshiftDataTrigger( - task_id=self.task_id, - poll_interval=self.poll_interval, - aws_conn_id=self.aws_conn_id, - query_ids=query_ids, + warnings.warn( + ( + "This module is deprecated and will be removed in 2.0.0." + "Please use `airflow.providers.amazon.aws.operators.redshift_data.RedshiftDataOperator`" + "and set `deferrable` param to `True` instead." ), - method_name="execute_complete", + DeprecationWarning, + stacklevel=2, ) - - def execute_complete(self, context: Context, event: Any = None) -> None: - """ - Callback for when the trigger fires - returns immediately. - Relies on trigger to throw an exception, otherwise it assumes execution was - successful. - """ - if event: - if "status" in event and event["status"] == "error": - msg = "context: {}, error message: {}".format(context, event["message"]) - raise AirflowException(msg) - elif "status" in event and event["status"] == "success": - self.log.info("%s completed successfully.", self.task_id) - else: - raise AirflowException("Did not receive valid event from the trigerrer") + kwargs["poll_interval"] = poll_interval + super().__init__(deferrable=True, **kwargs) diff --git a/astronomer/providers/amazon/aws/triggers/redshift_data.py b/astronomer/providers/amazon/aws/triggers/redshift_data.py index c3e6790c9..540135fd7 100644 --- a/astronomer/providers/amazon/aws/triggers/redshift_data.py +++ b/astronomer/providers/amazon/aws/triggers/redshift_data.py @@ -1,4 +1,7 @@ -from typing import Any, AsyncIterator, Dict, List, Tuple +from __future__ import annotations + +import warnings +from typing import Any, AsyncIterator from airflow.triggers.base import BaseTrigger, TriggerEvent @@ -9,6 +12,9 @@ class RedshiftDataTrigger(BaseTrigger): """ RedshiftDataTrigger is fired as deferred class with params to run the task in triggerer. + This class is deprecated and will be removed in 2.0.0. + Use :class: `~airflow.providers.amazon.aws.triggers.redshift_data.RedshiftDataTrigger` instead + :param task_id: task ID of the Dag :param poll_interval: polling period in seconds to check for the status :param aws_conn_id: AWS connection ID for redshift @@ -19,16 +25,25 @@ def __init__( self, task_id: str, poll_interval: int, - query_ids: List[str], + query_ids: list[str], aws_conn_id: str = "aws_default", ): + warnings.warn( + ( + "This module is deprecated and will be removed in 2.0.0." 
+ "Please use `airflow.providers.amazon.aws.triggers.redshift_data.RedshiftDataTrigger`" + ), + DeprecationWarning, + stacklevel=2, + ) + super().__init__() self.task_id = task_id self.poll_interval = poll_interval self.aws_conn_id = aws_conn_id self.query_ids = query_ids - def serialize(self) -> Tuple[str, Dict[str, Any]]: + def serialize(self) -> tuple[str, dict[str, Any]]: """Serializes RedshiftDataTrigger arguments and classpath.""" return ( "astronomer.providers.amazon.aws.triggers.redshift_data.RedshiftDataTrigger", @@ -40,7 +55,7 @@ def serialize(self) -> Tuple[str, Dict[str, Any]]: }, ) - async def run(self) -> AsyncIterator["TriggerEvent"]: + async def run(self) -> AsyncIterator[TriggerEvent]: """ Makes async connection and gets status for a list of queries submitted by the operator. Even if one of the queries has a non-successful state, the hook returns a failure event and the error diff --git a/tests/amazon/aws/operators/test_redshift_data.py b/tests/amazon/aws/operators/test_redshift_data.py index 5cff1b4f8..8c2860695 100644 --- a/tests/amazon/aws/operators/test_redshift_data.py +++ b/tests/amazon/aws/operators/test_redshift_data.py @@ -1,101 +1,16 @@ -from unittest import mock - -import pytest -from airflow.exceptions import AirflowException, TaskDeferred +from airflow.providers.amazon.aws.operators.redshift_data import RedshiftDataOperator from astronomer.providers.amazon.aws.operators.redshift_data import ( RedshiftDataOperatorAsync, ) -from astronomer.providers.amazon.aws.triggers.redshift_data import RedshiftDataTrigger -from tests.utils.airflow_util import create_context class TestRedshiftDataOperatorAsync: - DATABASE_NAME = "TEST_DATABASE" - TASK_ID = "fetch_data" - SQL_QUERY = "select * from any" - TASK = RedshiftDataOperatorAsync( - task_id=TASK_ID, - sql=SQL_QUERY, - database=DATABASE_NAME, - ) - - @mock.patch("astronomer.providers.amazon.aws.operators.redshift_data.RedshiftDataOperatorAsync.defer") - @mock.patch("astronomer.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn") - @mock.patch("astronomer.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query") - def test_redshift_data_op_async_finished_before_deferred(self, mock_execute, mock_conn, mock_defer): - mock_execute.return_value = ["test_query_id"], {} - mock_conn.describe_statement.return_value = { - "Status": "FINISHED", - } - self.TASK.execute(create_context(self.TASK)) - assert not mock_defer.called - - @mock.patch("astronomer.providers.amazon.aws.operators.redshift_data.RedshiftDataOperatorAsync.defer") - @mock.patch("astronomer.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn") - @mock.patch("astronomer.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query") - def test_redshift_data_op_async_aborted_before_deferred(self, mock_execute, mock_conn, mock_defer): - mock_execute.return_value = ["test_query_id"], {} - mock_conn.describe_statement.return_value = {"Status": "ABORTED"} - - with pytest.raises(AirflowException): - self.TASK.execute(create_context(self.TASK)) - - assert not mock_defer.called - - @mock.patch("astronomer.providers.amazon.aws.operators.redshift_data.RedshiftDataOperatorAsync.defer") - @mock.patch("astronomer.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn") - @mock.patch("astronomer.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query") - def test_redshift_data_op_async_failed_before_deferred(self, mock_execute, mock_conn, mock_defer): - mock_execute.return_value = ["test_query_id"], {} - 
-        mock_conn.describe_statement.return_value = {
-            "Status": "FAILED",
-            "QueryString": "test query",
-            "Error": "test error",
-        }
-
-        with pytest.raises(AirflowException):
-            self.TASK.execute(create_context(self.TASK))
-
-        assert not mock_defer.called
-
-    @pytest.mark.parametrize("status", ("SUBMITTED", "PICKED", "STARTED"))
-    @mock.patch("astronomer.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
-    @mock.patch("astronomer.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query")
-    def test_redshift_data_op_async(self, mock_execute, mock_conn, status):
-        mock_execute.return_value = ["test_query_id"], {}
-        mock_conn.describe_statement.return_value = {"Status": status}
-
-        with pytest.raises(TaskDeferred) as exc:
-            self.TASK.execute(create_context(self.TASK))
-        assert isinstance(exc.value.trigger, RedshiftDataTrigger), "Trigger is not a RedshiftDataTrigger"
-
-    @mock.patch("astronomer.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query")
-    def test_redshift_data_op_async_execute_query_error(self, mock_execute, context):
-        mock_execute.return_value = [], {"status": "error", "message": "Test exception"}
-        with pytest.raises(AirflowException):
-            self.TASK.execute(context)
-
-    def test_redshift_data_op_async_execute_failure(self, context):
-        """Tests that an AirflowException is raised in case of error event"""
-
-        with pytest.raises(AirflowException):
-            self.TASK.execute_complete(
-                context=None, event={"status": "error", "message": "test failure message"}
-            )
-
-    @pytest.mark.parametrize(
-        "event",
-        [None, {"status": "success", "message": "Job completed"}],
-    )
-    def test_redshift_data_op_async_execute_complete(self, event):
-        """Asserts that logging occurs as expected"""
-
-        if not event:
-            with pytest.raises(AirflowException) as exception_info:
-                self.TASK.execute_complete(context=None, event=None)
-            assert exception_info.value.args[0] == "Did not receive valid event from the trigerrer"
-        else:
-            with mock.patch.object(self.TASK.log, "info") as mock_log_info:
-                self.TASK.execute_complete(context=None, event=event)
-                mock_log_info.assert_called_with("%s completed successfully.", self.TASK_ID)
+    def test_init(self):
+        task = RedshiftDataOperatorAsync(
+            task_id="fetch_data",
+            sql="select * from any",
+            database="TEST_DATABASE",
+        )
+        assert isinstance(task, RedshiftDataOperator)
+        assert task.deferrable is True
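
Migration sketch (illustrative, not part of the patch): the deprecation messages above point users at the upstream operator in deferrable mode. A minimal example, assuming a version of apache-airflow-providers-amazon whose RedshiftDataOperator supports the `deferrable` flag; the task_id, sql, and database values are placeholders borrowed from the test above:

    from airflow.providers.amazon.aws.operators.redshift_data import RedshiftDataOperator

    # Equivalent replacement for RedshiftDataOperatorAsync: with deferrable=True
    # the operator submits the statement and hands polling off to the triggerer
    # instead of occupying a worker slot while the query runs.
    fetch_data = RedshiftDataOperator(
        task_id="fetch_data",
        database="TEST_DATABASE",
        sql="select * from any",
        deferrable=True,
    )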