Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deprecate EmrContainerSensorAsync and EmrStepSensorAsync #1390

Merged
merged 4 commits into from
Dec 21, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def check_dag_status(**kwargs: Any) -> None:
# [START howto_sensor_emr_step_async]
watch_step = EmrStepSensorAsync(
task_id="watch_step",
job_flow_id=create_job_flow.output, # type: ignore[arg-type]
job_flow_id=create_job_flow.output,
step_id="{{ task_instance.xcom_pull(task_ids='add_steps', key='return_value')[0] }}",
aws_conn_id=AWS_CONN_ID,
)
Expand Down
112 changes: 28 additions & 84 deletions astronomer/providers/amazon/aws/sensors/emr.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import warnings
from datetime import timedelta
from typing import Any

Expand All @@ -12,105 +13,48 @@
)

from astronomer.providers.amazon.aws.triggers.emr import (
EmrContainerSensorTrigger,
EmrJobFlowSensorTrigger,
EmrStepSensorTrigger,
)
from astronomer.providers.utils.sensor_util import poke, raise_error_or_skip_exception
from astronomer.providers.utils.sensor_util import raise_error_or_skip_exception
from astronomer.providers.utils.typing_compat import Context


class EmrContainerSensorAsync(EmrContainerSensor):
"""
EmrContainerSensorAsync is async version of EmrContainerSensor,
Asks for the state of the job run until it reaches a failure state or success state.
If the job run fails, the task will fail.

:param virtual_cluster_id: Reference Emr cluster id
:param job_id: job_id to check the state
:param max_retries: Number of times to poll for query state before
returning the current state, defaults to None
:param aws_conn_id: aws connection to use, defaults to ``aws_default``
:param poll_interval: Time in seconds to wait between two consecutive call to
check query status on athena, defaults to 10
This class is deprecated.
Please use :class:`~airflow.providers.amazon.aws.sensors.emr.EmrContainerSensor`.
"""

def execute(self, context: Context) -> None:
"""Defers trigger class to poll for state of the job run until it reaches a failure state or success state"""
if not poke(self, context):
self.defer(
timeout=timedelta(seconds=self.timeout),
trigger=EmrContainerSensorTrigger(
virtual_cluster_id=self.virtual_cluster_id,
job_id=self.job_id,
max_tries=self.max_retries,
aws_conn_id=self.aws_conn_id,
poll_interval=self.poll_interval,
),
method_name="execute_complete",
)

# Ignoring the override type check because the parent class specifies "context: Any" but specifying it as
# "context: Context" is accurate as it's more specific.
def execute_complete(self, context: Context, event: dict[str, str]) -> None: # type: ignore[override]
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
successful.
"""
if event:
if event["status"] == "error":
raise_error_or_skip_exception(self.soft_fail, event["message"])
self.log.info(event["message"])
return None
def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def]
warnings.warn(
(
"This module is deprecated. "
"Please use `airflow.providers.amazon.aws.sensors.emr.EmrContainerSensor` "
"and set deferrable to True instead."
),
DeprecationWarning,
stacklevel=2,
)
return super().__init__(*args, deferrable=True, **kwargs)


class EmrStepSensorAsync(EmrStepSensor):
"""
Async (deferring) version of EmrStepSensor

Asks for the state of the step until it reaches any of the target states.
If the sensor errors out, then the task will fail
With the default target states, sensor waits step to be COMPLETED.

For more details see
- https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/emr.html#EMR.Client.describe_step

:param job_flow_id: job_flow_id which contains the step check the state of
:param step_id: step to check the state of
:param target_states: the target states, sensor waits until
step reaches any of these states
:param failed_states: the failure states, sensor fails when
step reaches any of these states
This class is deprecated.
Please use :class:`~airflow.providers.amazon.aws.sensors.emr.EmrStepSensor`.
"""

def execute(self, context: Context) -> None:
"""Deferred and give control to trigger"""
if not poke(self, context):
self.defer(
timeout=timedelta(seconds=self.timeout),
trigger=EmrStepSensorTrigger(
job_flow_id=self.job_flow_id,
step_id=self.step_id,
target_states=self.target_states,
failed_states=self.failed_states,
aws_conn_id=self.aws_conn_id,
poke_interval=self.poke_interval,
),
method_name="execute_complete",
)

def execute_complete(self, context: Context, event: dict[str, Any]) -> None: # type: ignore[override]
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
successful.
"""
if event:
if event["status"] == "error":
raise_error_or_skip_exception(self.soft_fail, event["message"])
self.log.info(event.get("message"))
self.log.info("%s completed successfully.", self.job_flow_id)
def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def]
warnings.warn(
(
"This module is deprecated. "
"Please use `airflow.providers.amazon.aws.sensors.emr.EmrStepSensor` "
"and set deferrable to True instead."
),
DeprecationWarning,
stacklevel=2,
)
return super().__init__(*args, deferrable=True, **kwargs)


class EmrJobFlowSensorAsync(EmrJobFlowSensor):
Expand Down
108 changes: 22 additions & 86 deletions tests/amazon/aws/sensors/test_emr_sensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,17 @@

import pytest
from airflow.exceptions import AirflowException, TaskDeferred
from airflow.providers.amazon.aws.sensors.emr import (
EmrStepSensor,
)

from astronomer.providers.amazon.aws.sensors.emr import (
EmrContainerSensorAsync,
EmrJobFlowSensorAsync,
EmrStepSensorAsync,
)
from astronomer.providers.amazon.aws.triggers.emr import (
EmrContainerSensorTrigger,
EmrJobFlowSensorTrigger,
EmrStepSensorTrigger,
)

TASK_ID = "test_emr_container_sensor"
Expand All @@ -28,55 +29,28 @@


class TestEmrContainerSensorAsync:
TASK = EmrContainerSensorAsync(
task_id=TASK_ID,
virtual_cluster_id=VIRTUAL_CLUSTER_ID,
job_id=JOB_ID,
poll_interval=5,
max_retries=1,
aws_conn_id=AWS_CONN_ID,
)

@mock.patch(f"{MODULE}.EmrContainerSensorAsync.defer")
@mock.patch(f"{MODULE}.EmrContainerSensorAsync.poke", return_value=True)
def test_emr_container_sensor_async_finish_before_deferred(self, mock_poke, mock_defer, context):
"""Assert task is not deferred when it receives a finish status before deferring"""
self.TASK.execute(context)
assert not mock_defer.called

@mock.patch(f"{MODULE}.EmrContainerSensorAsync.poke", return_value=False)
def test_emr_container_sensor_async(self, mock_poke, context):
"""
Asserts that a task is deferred and a EmrContainerSensorTrigger will be fired
when the EmrContainerSensorAsync is executed.
"""

with pytest.raises(TaskDeferred) as exc:
self.TASK.execute(context)
assert isinstance(
exc.value.trigger, EmrContainerSensorTrigger
), "Trigger is not a EmrContainerSensorTrigger"

def test_emr_container_sensor_async_execute_failure(self, context):
"""Tests that an AirflowException is raised in case of error event"""

with pytest.raises(AirflowException):
self.TASK.execute_complete(
context=None, event={"status": "error", "message": "test failure message"}
)

def test_emr_container_sensor_async_execute_complete(self):
"""Asserts that logging occurs as expected"""

assert (
self.TASK.execute_complete(context=None, event={"status": "success", "message": "Job completed"})
is None
def test_init(self):
task = EmrContainerSensorAsync(
task_id=TASK_ID,
virtual_cluster_id=VIRTUAL_CLUSTER_ID,
job_id=JOB_ID,
poll_interval=5,
max_retries=1,
aws_conn_id=AWS_CONN_ID,
)
assert isinstance(task, EmrContainerSensorAsync)
assert task.deferrable is True

def test_emr_container_sensor_async_execute_complete_event_none(self):
"""Asserts that logging occurs as expected"""

assert self.TASK.execute_complete(context=None, event=None) is None
class TestEmrStepSensorAsync:
def test_init(self):
task = EmrStepSensorAsync(
task_id="emr_step_sensor",
job_flow_id=JOB_ID,
step_id=STEP_ID,
)
assert isinstance(task, EmrStepSensor)
assert task.deferrable is True


class TestEmrJobFlowSensorAsync:
Expand Down Expand Up @@ -140,41 +114,3 @@ def test_emr_job_flow_sensor_async_execute_complete_event_none(self):
"""Asserts that logging occurs as expected"""

assert self.TASK.execute_complete(context=None, event=None) is None


class TestEmrStepSensorAsync:
TASK = EmrStepSensorAsync(
task_id="emr_step_sensor",
job_flow_id=JOB_ID,
step_id=STEP_ID,
)

@mock.patch(f"{MODULE}.EmrStepSensorAsync.defer")
@mock.patch(f"{MODULE}.EmrStepSensorAsync.poke", return_value=True)
def test_emr_step_sensor_async_finish_before_deferred(self, mock_poke, mock_defer, context):
"""Assert task is not deferred when it receives a finish status before deferring"""
self.TASK.execute(context)
assert not mock_defer.called

@mock.patch(f"{MODULE}.EmrStepSensorAsync.poke", return_value=False)
def test_emr_step_sensor_async(self, mock_poke, context):
"""Assert execute method defer for EmrStepSensorAsync sensor"""

with pytest.raises(TaskDeferred) as exc:
self.TASK.execute(context)
assert isinstance(exc.value.trigger, EmrStepSensorTrigger), "Trigger is not a EmrStepSensorTrigger"

def test_emr_step_sensor_execute_complete_success(self):
"""Assert execute_complete log success message when triggerer fire with target state"""

with mock.patch.object(self.TASK.log, "info") as mock_log_info:
self.TASK.execute_complete(
context={}, event={"status": "success", "message": "Job flow currently COMPLETED"}
)
mock_log_info.assert_called_with("%s completed successfully.", "j-T0CT8Z0C20NT")

def test_emr_step_sensor_execute_complete_failure(self):
"""Assert execute_complete method fail"""

with pytest.raises(AirflowException):
self.TASK.execute_complete(context={}, event={"status": "error", "message": ""})