diff --git a/astronomer/providers/amazon/aws/sensors/batch.py b/astronomer/providers/amazon/aws/sensors/batch.py
index 6f060a116..7c49018ae 100644
--- a/astronomer/providers/amazon/aws/sensors/batch.py
+++ b/astronomer/providers/amazon/aws/sensors/batch.py
@@ -1,68 +1,22 @@
 import warnings
-from datetime import timedelta
-from typing import Any, Dict
 
 from airflow.providers.amazon.aws.sensors.batch import BatchSensor
 
-from astronomer.providers.amazon.aws.triggers.batch import BatchSensorTrigger
-from astronomer.providers.utils.sensor_util import poke, raise_error_or_skip_exception
-from astronomer.providers.utils.typing_compat import Context
-
 
 class BatchSensorAsync(BatchSensor):
     """
-    Given a job ID of a Batch Job, poll for the job status asynchronously until it
-    reaches a failure or a success state.
-    If the job fails, the task will fail.
-
-    .. see also::
-        For more information on how to use this sensor, take a look at the guide:
-        :ref:`howto/sensor:BatchSensor`
-
-    :param job_id: Batch job_id to check the state for
-    :param aws_conn_id: aws connection to use, defaults to 'aws_default'
-    :param region_name: region name to use in AWS Hook
-        Override the region_name in connection (if provided)
-    :param poll_interval: polling period in seconds to check for the status of the job
+    This class is deprecated.
+    Please use :class: `~airflow.providers.amazon.aws.sensors.batch.BatchSensor`.
     """
 
-    def __init__(
-        self,
-        *,
-        poll_interval: float = 5,
-        **kwargs: Any,
-    ):
-        # TODO: Remove once deprecated
-        if poll_interval:
-            self.poke_interval = poll_interval
-            warnings.warn(
-                "Argument `poll_interval` is deprecated and will be removed "
-                "in a future release. Please use `poke_interval` instead.",
-                DeprecationWarning,
-                stacklevel=2,
-            )
-        super().__init__(**kwargs)
-
-    def execute(self, context: Context) -> None:
-        """Defers trigger class to poll for state of the job run until it reaches a failure or a success state"""
-        if not poke(self, context):
-            self.defer(
-                timeout=timedelta(seconds=self.timeout),
-                trigger=BatchSensorTrigger(
-                    job_id=self.job_id,
-                    aws_conn_id=self.aws_conn_id,
-                    region_name=self.region_name,
-                    poke_interval=self.poke_interval,
-                ),
-                method_name="execute_complete",
-            )
-
-    def execute_complete(self, context: Context, event: Dict[str, Any]) -> None:
-        """
-        Callback for when the trigger fires - returns immediately.
-        Relies on trigger to throw an exception, otherwise it assumes execution was
-        successful.
-        """
-        if "status" in event and event["status"] == "error":
-            raise_error_or_skip_exception(self.soft_fail, event["message"])
-        self.log.info(event["message"])
+    def __init__(self, *args, **kwargs) -> None:  # type: ignore[no-untyped-def]
+        warnings.warn(
+            (
+                "This module is deprecated. "
+                "Please use `airflow.providers.amazon.aws.sensors.batch.BatchSensor` "
+                "and set deferrable to True instead."
+            ),
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        return super().__init__(*args, deferrable=True, **kwargs)
diff --git a/astronomer/providers/amazon/aws/triggers/batch.py b/astronomer/providers/amazon/aws/triggers/batch.py
index 981e4314c..01b1be301 100644
--- a/astronomer/providers/amazon/aws/triggers/batch.py
+++ b/astronomer/providers/amazon/aws/triggers/batch.py
@@ -1,4 +1,3 @@
-import asyncio
 from typing import Any, AsyncIterator, Dict, Optional, Tuple
 
 from airflow.triggers.base import BaseTrigger, TriggerEvent
@@ -71,64 +70,3 @@ async def run(self) -> AsyncIterator["TriggerEvent"]:
             yield TriggerEvent({"status": "error", "message": error_message})
         except Exception as e:
             yield TriggerEvent({"status": "error", "message": str(e)})
-
-
-class BatchSensorTrigger(BaseTrigger):
-    """
-    Checks for the status of a submitted job_id to AWS Batch until it reaches a failure or a success state.
-    BatchSensorTrigger is fired as deferred class with params to poll the job state in Triggerer
-
-    :param job_id: the job ID, to poll for job completion or not
-    :param aws_conn_id: connection id of AWS credentials / region name. If None,
-        credential boto3 strategy will be used
-    :param region_name: AWS region name to use
-        Override the region_name in connection (if provided)
-    :param poke_interval: polling period in seconds to check for the status of the job
-    """
-
-    def __init__(
-        self,
-        job_id: str,
-        region_name: Optional[str],
-        aws_conn_id: Optional[str] = "aws_default",
-        poke_interval: float = 5,
-    ):
-        super().__init__()
-        self.job_id = job_id
-        self.aws_conn_id = aws_conn_id
-        self.region_name = region_name
-        self.poke_interval = poke_interval
-
-    def serialize(self) -> Tuple[str, Dict[str, Any]]:
-        """Serializes BatchSensorTrigger arguments and classpath."""
-        return (
-            "astronomer.providers.amazon.aws.triggers.batch.BatchSensorTrigger",
-            {
-                "job_id": self.job_id,
-                "aws_conn_id": self.aws_conn_id,
-                "region_name": self.region_name,
-                "poke_interval": self.poke_interval,
-            },
-        )
-
-    async def run(self) -> AsyncIterator["TriggerEvent"]:
-        """
-        Make async connection using aiobotocore library to AWS Batch,
-        periodically poll for the Batch job status
-
-        The status that indicates job completion are: 'SUCCEEDED'|'FAILED'.
-        """
-        hook = BatchClientHookAsync(job_id=self.job_id, aws_conn_id=self.aws_conn_id)
-        try:
-            while True:
-                response = await hook.get_job_description(self.job_id)
-                state = response["status"]
-                if state == BatchClientHookAsync.SUCCESS_STATE:
-                    success_message = f"{self.job_id} was completed successfully"
-                    yield TriggerEvent({"status": "success", "message": success_message})
-                if state == BatchClientHookAsync.FAILURE_STATE:
-                    error_message = f"{self.job_id} failed"
-                    yield TriggerEvent({"status": "error", "message": error_message})
-                await asyncio.sleep(self.poke_interval)
-        except Exception as e:
-            yield TriggerEvent({"status": "error", "message": str(e)})
diff --git a/tests/amazon/aws/sensors/test_batch_sensors.py b/tests/amazon/aws/sensors/test_batch_sensors.py
index f4fa6f19c..b7688f536 100644
--- a/tests/amazon/aws/sensors/test_batch_sensors.py
+++ b/tests/amazon/aws/sensors/test_batch_sensors.py
@@ -1,73 +1,17 @@
-from unittest import mock
-
-import pytest
-from airflow.exceptions import AirflowException, TaskDeferred
+from airflow.providers.amazon.aws.sensors.batch import BatchSensor
 
 from astronomer.providers.amazon.aws.sensors.batch import BatchSensorAsync
-from astronomer.providers.amazon.aws.triggers.batch import BatchSensorTrigger
 
 MODULE = "astronomer.providers.amazon.aws.sensors.batch"
 
 
 class TestBatchSensorAsync:
-    JOB_ID = "8ba9d676-4108-4474-9dca-8bbac1da9b19"
-    AWS_CONN_ID = "airflow_test"
-    REGION_NAME = "eu-west-1"
-    TASK = BatchSensorAsync(
-        task_id="task",
-        job_id=JOB_ID,
-        aws_conn_id=AWS_CONN_ID,
-        region_name=REGION_NAME,
-    )
-
-    @mock.patch(f"{MODULE}.BatchSensorAsync.defer")
-    @mock.patch(f"{MODULE}.BatchSensorAsync.poke", return_value=True)
-    def test_batch_sensor_async_finish_before_deferred(self, mock_poke, mock_defer, context):
-        """Assert task is not deferred when it receives a finish status before deferring"""
-        self.TASK.execute(context)
-        assert not mock_defer.called
-
-    @mock.patch(f"{MODULE}.BatchSensorAsync.poke", return_value=False)
-    def test_batch_sensor_async(self, context):
-        """
-        Asserts that a task is deferred and a BatchSensorTrigger will be fired
-        when the BatchSensorAsync is executed.
-        """
-
-        with pytest.raises(TaskDeferred) as exc:
-            self.TASK.execute(context)
-        assert isinstance(exc.value.trigger, BatchSensorTrigger), "Trigger is not a BatchSensorTrigger"
-
-    def test_batch_sensor_async_execute_failure(self, context):
-        """Tests that an AirflowException is raised in case of error event"""
-
-        with pytest.raises(AirflowException) as exc_info:
-            self.TASK.execute_complete(
-                context=None, event={"status": "error", "message": "test failure message"}
-            )
-
-        assert str(exc_info.value) == "test failure message"
-
-    @pytest.mark.parametrize(
-        "event",
-        [{"status": "success", "message": f"AWS Batch job ({JOB_ID}) succeeded"}],
-    )
-    def test_batch_sensor_async_execute_complete(self, caplog, event):
-        """Tests that execute_complete method returns None and that it prints expected log"""
-
-        with mock.patch.object(self.TASK.log, "info") as mock_log_info:
-            assert self.TASK.execute_complete(context=None, event=event) is None
-
-        mock_log_info.assert_called_with(event["message"])
-
-    def test_poll_interval_deprecation_warning(self):
-        """Test DeprecationWarning for BatchSensorAsync by setting param poll_interval"""
-        # TODO: Remove once deprecated
-        with pytest.warns(expected_warning=DeprecationWarning):
-            BatchSensorAsync(
-                task_id="task",
-                job_id=self.JOB_ID,
-                aws_conn_id=self.AWS_CONN_ID,
-                region_name=self.REGION_NAME,
-                poll_interval=5.0,
-            )
+    def test_init(self):
+        task = BatchSensorAsync(
+            task_id="task",
+            job_id="8ba9d676-4108-4474-9dca-8bbac1da9b19",
+            aws_conn_id="airflow_test",
+            region_name="eu-west-1",
+        )
+        assert isinstance(task, BatchSensor)
+        assert task.deferrable is True
diff --git a/tests/amazon/aws/triggers/test_batch.py b/tests/amazon/aws/triggers/test_batch.py
index 67dba5e05..3731ac2bd 100644
--- a/tests/amazon/aws/triggers/test_batch.py
+++ b/tests/amazon/aws/triggers/test_batch.py
@@ -4,10 +4,7 @@
 import pytest
 from airflow.triggers.base import TriggerEvent
 
-from astronomer.providers.amazon.aws.triggers.batch import (
-    BatchOperatorTrigger,
-    BatchSensorTrigger,
-)
+from astronomer.providers.amazon.aws.triggers.batch import BatchOperatorTrigger
 
 JOB_NAME = "51455483-c62c-48ac-9b88-53a6a725baa3"
 JOB_ID = "8ba9d676-4108-4474-9dca-8bbac1da9b19"
@@ -94,85 +91,3 @@ async def test_batch_trigger_exception(self, mock_response):
         task = [i async for i in self.TRIGGER.run()]
         assert len(task) == 1
         assert TriggerEvent({"status": "error", "message": "Test exception"}) in task
-
-
-class TestBatchSensorTrigger:
-    TRIGGER = BatchSensorTrigger(
-        job_id=JOB_ID,
-        region_name=REGION_NAME,
-        aws_conn_id=AWS_CONN_ID,
-        poke_interval=POKE_INTERVAL,
-    )
-
-    def test_batch_sensor_trigger_serialization(self):
-        """
-        Asserts that the BatchSensorTrigger correctly serializes its arguments
-        and classpath.
-        """
-
-        classpath, kwargs = self.TRIGGER.serialize()
-        assert classpath == "astronomer.providers.amazon.aws.triggers.batch.BatchSensorTrigger"
-        assert kwargs == {
-            "job_id": JOB_ID,
-            "region_name": "eu-west-1",
-            "aws_conn_id": "airflow_test",
-            "poke_interval": POKE_INTERVAL,
-        }
-
-    @pytest.mark.asyncio
-    @mock.patch("astronomer.providers.amazon.aws.hooks.batch_client.BatchClientHookAsync.get_job_description")
-    async def test_batch_sensor_trigger_run(self, mock_response):
-        """Trigger the BatchSensorTrigger and check if the task is in running state."""
-        mock_response.return_value = {"status": "RUNNABLE"}
-
-        task = asyncio.create_task(self.TRIGGER.run().__anext__())
-        await asyncio.sleep(0.5)
-        # TriggerEvent was not returned
-        assert task.done() is False
-        asyncio.get_event_loop().stop()
-
-    @pytest.mark.asyncio
-    @mock.patch("astronomer.providers.amazon.aws.hooks.batch_client.BatchClientHookAsync.get_job_description")
-    async def test_batch_sensor_trigger_completed(self, mock_response):
-        """Test if the success event is returned from trigger."""
-        mock_response.return_value = {"status": "SUCCEEDED"}
-        trigger = BatchSensorTrigger(
-            job_id=JOB_ID,
-            region_name=REGION_NAME,
-            aws_conn_id=AWS_CONN_ID,
-        )
-        generator = trigger.run()
-        actual_response = await generator.asend(None)
-        assert (
-            TriggerEvent({"status": "success", "message": f"{JOB_ID} was completed successfully"})
-            == actual_response
-        )
-
-    @pytest.mark.asyncio
-    @mock.patch("astronomer.providers.amazon.aws.hooks.batch_client.BatchClientHookAsync.get_job_description")
-    async def test_batch_sensor_trigger_failure(self, mock_response):
-        """Test if the failure event is returned from trigger."""
-        mock_response.return_value = {"status": "FAILED"}
-        trigger = BatchSensorTrigger(
-            job_id=JOB_ID,
-            region_name=REGION_NAME,
-            aws_conn_id=AWS_CONN_ID,
-        )
-        generator = trigger.run()
-        actual_response = await generator.asend(None)
-        assert TriggerEvent({"status": "error", "message": f"{JOB_ID} failed"}) == actual_response
-
-    @pytest.mark.asyncio
-    @mock.patch("astronomer.providers.amazon.aws.hooks.batch_client.BatchClientHookAsync.get_job_description")
-    async def test_batch_sensor_trigger_exception(self, mock_response):
-        """Test if the exception is raised from trigger."""
-        mock_response.side_effect = Exception("Test exception")
-        trigger = BatchSensorTrigger(
-            job_id=JOB_ID,
-            region_name=REGION_NAME,
-            aws_conn_id=AWS_CONN_ID,
-        )
-        task = [i async for i in trigger.run()]
-        assert len(task) == 1
-
-        assert TriggerEvent({"status": "error", "message": "Test exception"}) in task