Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deprecate BatchSensorAsync #1391

Merged
merged 3 commits into from
Dec 21, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 13 additions & 59 deletions astronomer/providers/amazon/aws/sensors/batch.py
Original file line number Diff line number Diff line change
@@ -1,68 +1,22 @@
import warnings
from datetime import timedelta
from typing import Any, Dict

from airflow.providers.amazon.aws.sensors.batch import BatchSensor

from astronomer.providers.amazon.aws.triggers.batch import BatchSensorTrigger
from astronomer.providers.utils.sensor_util import poke, raise_error_or_skip_exception
from astronomer.providers.utils.typing_compat import Context


class BatchSensorAsync(BatchSensor):
"""
Given a job ID of a Batch Job, poll for the job status asynchronously until it
reaches a failure or a success state.
If the job fails, the task will fail.

.. seealso::
For more information on how to use this sensor, take a look at the guide:
:ref:`howto/sensor:BatchSensor`

:param job_id: Batch job_id to check the state for
:param aws_conn_id: aws connection to use, defaults to 'aws_default'
:param region_name: region name to use in AWS Hook
Override the region_name in connection (if provided)
:param poll_interval: polling period in seconds to check for the status of the job
This class is deprecated.
Please use :class:`~airflow.providers.amazon.aws.sensors.batch.BatchSensor`.
"""

def __init__(self, *, poll_interval: float = 5, **kwargs: Any):
    """
    Build the sensor while still accepting the legacy ``poll_interval`` keyword.

    :param poll_interval: deprecated alias of ``poke_interval``; a truthy value is copied
        onto ``self.poke_interval`` and a ``DeprecationWarning`` is emitted
    :param kwargs: forwarded unchanged to the parent constructor
    """
    # TODO: Remove once deprecated
    # NOTE(review): the default of 5 is truthy, so this branch also runs when the caller
    # never passes ``poll_interval`` -- confirm that is intended.
    if poll_interval:
        self.poke_interval = poll_interval
        deprecation_message = (
            "Argument `poll_interval` is deprecated and will be removed "
            "in a future release. Please use `poke_interval` instead."
        )
        # stacklevel=2 attributes the warning to the code instantiating the sensor.
        warnings.warn(deprecation_message, DeprecationWarning, stacklevel=2)
    super().__init__(**kwargs)

def execute(self, context: Context) -> None:
    """
    Poke once and defer on a ``BatchSensorTrigger`` when the poke is not yet satisfied.

    If ``poke(self, context)`` returns a truthy value the method returns immediately;
    otherwise the task defers and resumes in ``execute_complete``.
    """
    if poke(self, context):
        return

    wait_limit = timedelta(seconds=self.timeout)
    batch_trigger = BatchSensorTrigger(
        job_id=self.job_id,
        aws_conn_id=self.aws_conn_id,
        region_name=self.region_name,
        poke_interval=self.poke_interval,
    )
    self.defer(timeout=wait_limit, trigger=batch_trigger, method_name="execute_complete")

def execute_complete(self, context: Context, event: Dict[str, Any]) -> None:
    """
    Callback invoked once the trigger fires.

    An event whose ``status`` is ``"error"`` is handed, together with ``soft_fail``, to
    ``raise_error_or_skip_exception``; afterwards the event message is logged at info level.
    """
    reported_error = event.get("status") == "error"
    if reported_error:
        raise_error_or_skip_exception(self.soft_fail, event["message"])
    self.log.info(event["message"])
def __init__(self, *args, **kwargs) -> None:  # type: ignore[no-untyped-def]
    """
    Emit a deprecation warning and construct the parent sensor in deferrable mode.

    All positional and keyword arguments are forwarded to the parent constructor, with
    ``deferrable`` always set to ``True``.
    """
    warnings.warn(
        (
            "This module is deprecated. "
            "Please use `airflow.providers.amazon.aws.sensors.batch.BatchSensor` "
            "and set deferrable to True instead."
        ),
        DeprecationWarning,
        stacklevel=2,
    )
    # Set the flag through ``kwargs`` instead of passing ``deferrable=True`` next to
    # ``**kwargs``: a caller that already supplies ``deferrable`` would otherwise fail with
    # "TypeError: ... got multiple values for keyword argument 'deferrable'".
    kwargs["deferrable"] = True
    super().__init__(*args, **kwargs)
76 changes: 10 additions & 66 deletions tests/amazon/aws/sensors/test_batch_sensors.py
Original file line number Diff line number Diff line change
@@ -1,73 +1,17 @@
from unittest import mock

import pytest
from airflow.exceptions import AirflowException, TaskDeferred
from airflow.providers.amazon.aws.sensors.batch import BatchSensor

from astronomer.providers.amazon.aws.sensors.batch import BatchSensorAsync
from astronomer.providers.amazon.aws.triggers.batch import BatchSensorTrigger

MODULE = "astronomer.providers.amazon.aws.sensors.batch"


class TestBatchSensorAsync:
JOB_ID = "8ba9d676-4108-4474-9dca-8bbac1da9b19"
AWS_CONN_ID = "airflow_test"
REGION_NAME = "eu-west-1"
TASK = BatchSensorAsync(
task_id="task",
job_id=JOB_ID,
aws_conn_id=AWS_CONN_ID,
region_name=REGION_NAME,
)

@mock.patch(f"{MODULE}.BatchSensorAsync.defer")
@mock.patch(f"{MODULE}.BatchSensorAsync.poke", return_value=True)
def test_batch_sensor_async_finish_before_deferred(self, mock_poke, mock_defer, context):
    """Check that ``defer`` is never called when the patched ``poke`` already returns True."""
    self.TASK.execute(context)
    mock_defer.assert_not_called()

@mock.patch(f"{MODULE}.BatchSensorAsync.poke", return_value=False)
def test_batch_sensor_async(self, mock_poke, context):
    """
    Asserts that a task is deferred and a BatchSensorTrigger will be fired
    when the BatchSensorAsync is executed.
    """
    # ``mock.patch`` injects the patched ``poke`` positionally, so it needs its own
    # parameter; without it the mock is bound to ``context`` and the fixture is never used.
    with pytest.raises(TaskDeferred) as exc:
        self.TASK.execute(context)
    # Inspect the captured exception only after the ``raises`` block has exited: a statement
    # placed after the raising call inside the block would never run.
    assert isinstance(exc.value.trigger, BatchSensorTrigger), "Trigger is not a BatchSensorTrigger"

def test_batch_sensor_async_execute_failure(self, context):
    """An error event passed to ``execute_complete`` must surface as an AirflowException."""
    error_event = {"status": "error", "message": "test failure message"}

    with pytest.raises(AirflowException) as exc_info:
        self.TASK.execute_complete(context=None, event=error_event)

    assert str(exc_info.value) == "test failure message"

@pytest.mark.parametrize(
    "event",
    [{"status": "success", "message": f"AWS Batch job ({JOB_ID}) succeeded"}],
)
def test_batch_sensor_async_execute_complete(self, caplog, event):
    """A success event makes ``execute_complete`` return None and log the event message."""
    with mock.patch.object(self.TASK.log, "info") as mock_log_info:
        outcome = self.TASK.execute_complete(context=None, event=event)
        assert outcome is None

    mock_log_info.assert_called_with(event["message"])

def test_poll_interval_deprecation_warning(self):
    """Passing ``poll_interval`` to BatchSensorAsync must emit a DeprecationWarning."""
    # TODO: Remove once deprecated
    sensor_kwargs = {
        "task_id": "task",
        "job_id": self.JOB_ID,
        "aws_conn_id": self.AWS_CONN_ID,
        "region_name": self.REGION_NAME,
        "poll_interval": 5.0,
    }
    with pytest.warns(DeprecationWarning):
        BatchSensorAsync(**sensor_kwargs)
def test_init(self):
    """The deprecated class must build a ``BatchSensor`` whose ``deferrable`` flag is True."""
    init_kwargs = {
        "task_id": "task",
        "job_id": "8ba9d676-4108-4474-9dca-8bbac1da9b19",
        "aws_conn_id": "airflow_test",
        "region_name": "eu-west-1",
    }
    task = BatchSensorAsync(**init_kwargs)
    assert isinstance(task, BatchSensor)
    assert task.deferrable is True