Skip to content

Commit

Permalink
Deprecate EmrContainerSensorAsync and EmrStepSensorAsync (#1390)
Browse files Browse the repository at this point in the history
* feat(amazon): deprecate EmrStepSensorAsync and EmrContainerSensorAsync
* feat(amazon): remove EmrContainerSensorTrigger and EmrStepSensorTrigger
  • Loading branch information
Lee-W authored Dec 21, 2023
1 parent 6781996 commit f0061d3
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 495 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def check_dag_status(**kwargs: Any) -> None:
# [START howto_sensor_emr_step_async]
watch_step = EmrStepSensorAsync(
task_id="watch_step",
job_flow_id=create_job_flow.output, # type: ignore[arg-type]
job_flow_id=create_job_flow.output,
step_id="{{ task_instance.xcom_pull(task_ids='add_steps', key='return_value')[0] }}",
aws_conn_id=AWS_CONN_ID,
)
Expand Down
112 changes: 28 additions & 84 deletions astronomer/providers/amazon/aws/sensors/emr.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import warnings
from datetime import timedelta
from typing import Any

Expand All @@ -12,105 +13,48 @@
)

from astronomer.providers.amazon.aws.triggers.emr import (
EmrContainerSensorTrigger,
EmrJobFlowSensorTrigger,
EmrStepSensorTrigger,
)
from astronomer.providers.utils.sensor_util import poke, raise_error_or_skip_exception
from astronomer.providers.utils.sensor_util import raise_error_or_skip_exception
from astronomer.providers.utils.typing_compat import Context


class EmrContainerSensorAsync(EmrContainerSensor):
"""
EmrContainerSensorAsync is async version of EmrContainerSensor,
Asks for the state of the job run until it reaches a failure state or success state.
If the job run fails, the task will fail.
:param virtual_cluster_id: Reference Emr cluster id
:param job_id: job_id to check the state
:param max_retries: Number of times to poll for query state before
returning the current state, defaults to None
:param aws_conn_id: aws connection to use, defaults to ``aws_default``
:param poll_interval: Time in seconds to wait between two consecutive call to
check query status on athena, defaults to 10
This class is deprecated.
Please use :class: `~airflow.providers.amazon.aws.sensors.emr.EmrContainerSensor`.
"""

def execute(self, context: Context) -> None:
    """
    Poke once and defer to ``EmrContainerSensorTrigger`` when the poke is not yet satisfied.

    If ``poke(self, context)`` returns a truthy value nothing else happens. Otherwise the task is
    deferred with a timeout of ``self.timeout`` seconds and resumes in ``execute_complete``.
    """
    if poke(self, context):
        return

    # Built in the same order the original call evaluated its arguments: timeout first, then trigger.
    defer_timeout = timedelta(seconds=self.timeout)
    container_trigger = EmrContainerSensorTrigger(
        virtual_cluster_id=self.virtual_cluster_id,
        job_id=self.job_id,
        max_tries=self.max_retries,
        aws_conn_id=self.aws_conn_id,
        poll_interval=self.poll_interval,
    )
    self.defer(
        timeout=defer_timeout,
        trigger=container_trigger,
        method_name="execute_complete",
    )

# The parent class declares "context: Any"; annotating it as "context: Context" is more specific and
# accurate, which is why the override type check is ignored here.
def execute_complete(self, context: Context, event: dict[str, str]) -> None:  # type: ignore[override]
    """
    Resume point invoked once the trigger has fired.

    An empty event is ignored. An event whose ``status`` is ``"error"`` is handed to
    ``raise_error_or_skip_exception`` together with ``self.soft_fail``; the event message is then logged.
    """
    if not event:
        return None

    is_error = event["status"] == "error"
    if is_error:
        raise_error_or_skip_exception(self.soft_fail, event["message"])
    self.log.info(event["message"])
    return None
def __init__(self, *args, **kwargs) -> None:  # type: ignore[no-untyped-def]
    """
    Emit a ``DeprecationWarning`` and initialise the parent sensor with ``deferrable=True``.

    All positional and keyword arguments are forwarded unchanged to the parent constructor.
    """
    deprecation_message = (
        "This module is deprecated. "
        "Please use `airflow.providers.amazon.aws.sensors.emr.EmrContainerSensor` "
        "and set deferrable to True instead."
    )
    # stacklevel=2 points the warning at the frame that called this __init__ rather than at this line.
    warnings.warn(deprecation_message, DeprecationWarning, stacklevel=2)
    super().__init__(*args, deferrable=True, **kwargs)


class EmrStepSensorAsync(EmrStepSensor):
"""
Async (deferring) version of EmrStepSensor
Asks for the state of the step until it reaches any of the target states.
If the sensor errors out, then the task will fail
With the default target states, sensor waits step to be COMPLETED.
For more details see
- https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/emr.html#EMR.Client.describe_step
:param job_flow_id: job_flow_id which contains the step check the state of
:param step_id: step to check the state of
:param target_states: the target states, sensor waits until
step reaches any of these states
:param failed_states: the failure states, sensor fails when
step reaches any of these states
This class is deprecated.
Please use :class: `~airflow.providers.amazon.aws.sensors.emr.EmrStepSensor`.
"""

def execute(self, context: Context) -> None:
    """
    Poke once and defer to ``EmrStepSensorTrigger`` when the poke is not yet satisfied.

    If ``poke(self, context)`` returns a truthy value nothing else happens. Otherwise the task is
    deferred with a timeout of ``self.timeout`` seconds and resumes in ``execute_complete``.
    """
    if poke(self, context):
        return

    # Built in the same order the original call evaluated its arguments: timeout first, then trigger.
    defer_timeout = timedelta(seconds=self.timeout)
    step_trigger = EmrStepSensorTrigger(
        job_flow_id=self.job_flow_id,
        step_id=self.step_id,
        target_states=self.target_states,
        failed_states=self.failed_states,
        aws_conn_id=self.aws_conn_id,
        poke_interval=self.poke_interval,
    )
    self.defer(
        timeout=defer_timeout,
        trigger=step_trigger,
        method_name="execute_complete",
    )

def execute_complete(self, context: Context, event: dict[str, Any]) -> None: # type: ignore[override]
"""
Callback invoked when the trigger fires.

An event whose ``status`` is ``"error"`` is passed to ``raise_error_or_skip_exception`` together
with ``self.soft_fail``. The event message and a completion line naming ``self.job_flow_id`` are
written to the task log.

:param context: task context (unused by this method)
:param event: payload produced by the trigger; read for its ``status`` and ``message`` keys
"""
# NOTE(review): indentation is lost in this rendering, so whether the two log calls below sit inside
# ``if event:`` cannot be confirmed from here -- verify against the repository before relying on it.
if event:
if event["status"] == "error":
raise_error_or_skip_exception(self.soft_fail, event["message"])
# ``.get`` is used here, so a missing "message" key logs None rather than raising KeyError.
self.log.info(event.get("message"))
self.log.info("%s completed successfully.", self.job_flow_id)
def __init__(self, *args, **kwargs) -> None:  # type: ignore[no-untyped-def]
    """
    Emit a ``DeprecationWarning`` and initialise the parent sensor with ``deferrable=True``.

    All positional and keyword arguments are forwarded unchanged to the parent constructor.
    """
    deprecation_message = (
        "This module is deprecated. "
        "Please use `airflow.providers.amazon.aws.sensors.emr.EmrStepSensor` "
        "and set deferrable to True instead."
    )
    # stacklevel=2 points the warning at the frame that called this __init__ rather than at this line.
    warnings.warn(deprecation_message, DeprecationWarning, stacklevel=2)
    super().__init__(*args, deferrable=True, **kwargs)


class EmrJobFlowSensorAsync(EmrJobFlowSensor):
Expand Down
114 changes: 0 additions & 114 deletions astronomer/providers/amazon/aws/triggers/emr.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from astronomer.providers.amazon.aws.hooks.emr import (
EmrContainerHookAsync,
EmrJobFlowHookAsync,
EmrStepSensorHookAsync,
)


Expand Down Expand Up @@ -38,52 +37,6 @@ def __init__(
super().__init__(**kwargs)


class EmrContainerSensorTrigger(EmrContainerBaseTrigger):
    """
    Poll the status of an EMR container job run until it reaches a terminal state.

    The attributes read here (``virtual_cluster_id``, ``job_id``, ``aws_conn_id``, ``max_tries``,
    ``poll_interval``) are presumably set by ``EmrContainerBaseTrigger.__init__`` -- confirm against
    the base class.
    """

    def serialize(self) -> Tuple[str, Dict[str, Any]]:
        """Serializes EmrContainerSensorTrigger arguments and classpath."""
        return (
            "astronomer.providers.amazon.aws.triggers.emr.EmrContainerSensorTrigger",
            {
                "virtual_cluster_id": self.virtual_cluster_id,
                "job_id": self.job_id,
                "aws_conn_id": self.aws_conn_id,
                "max_tries": self.max_tries,
                "poll_interval": self.poll_interval,
            },
        )

    async def run(self) -> AsyncIterator["TriggerEvent"]:
        """
        Make an async connection to the EMR container and poll for the job state.

        Yields exactly one ``TriggerEvent`` and then stops:

        * ``status="error"`` when the job is ``FAILED``, ``CANCELLED`` or ``CANCEL_PENDING``,
          when ``max_tries`` polls have been used up, or when any exception is raised while polling;
        * ``status="success"`` for any other state that is not pending/submitted/running.
        """
        hook = EmrContainerHookAsync(aws_conn_id=self.aws_conn_id, virtual_cluster_id=self.virtual_cluster_id)
        try:
            try_number: int = 1
            while True:
                query_status = await hook.check_job_status(job_id=self.job_id)
                if query_status is None or query_status in ("PENDING", "SUBMITTED", "RUNNING"):
                    await asyncio.sleep(self.poll_interval)
                elif query_status in ("FAILED", "CANCELLED", "CANCEL_PENDING"):
                    msg = f"EMR Containers sensor failed {query_status}"
                    yield TriggerEvent({"status": "error", "message": msg})
                    # Terminal state: stop instead of re-polling without a sleep if iterated further.
                    return
                else:
                    msg = "EMR Containers sensors completed"
                    yield TriggerEvent({"status": "success", "message": msg})
                    return

                if self.max_tries and try_number >= self.max_tries:
                    yield TriggerEvent(
                        {
                            "status": "error",
                            "message": "Timeout: Maximum retry limit exceed",
                            "job_id": self.job_id,
                        }
                    )
                    # The retry budget is spent; polling on would exceed max_tries.
                    return

                try_number += 1
        except Exception as e:
            # Any polling failure is reported to the sensor as an error event rather than propagated.
            yield TriggerEvent({"status": "error", "message": str(e)})


class EmrContainerOperatorTrigger(EmrContainerBaseTrigger):
"""Poll for the status of EMR container until reaches terminal state"""

Expand Down Expand Up @@ -161,73 +114,6 @@ async def run(self) -> AsyncIterator["TriggerEvent"]:
yield TriggerEvent({"status": "error", "message": str(e)})


class EmrStepSensorTrigger(BaseTrigger):
    """
    A trigger that fires once AWS EMR cluster step reaches either target or failed state

    :param job_flow_id: job_flow_id which contains the step check the state of
    :param step_id: step to check the state of
    :param aws_conn_id: aws connection to use (required; this trigger defines no default)
    :param poke_interval: Time in seconds to wait between two consecutive call to
        check emr cluster step state
    :param target_states: the target states, sensor waits until
        step reaches any of these states; ``["COMPLETED"]`` when not given
    :param failed_states: the failure states, sensor fails when
        step reaches any of these states; ``["CANCELLED", "FAILED", "INTERRUPTED"]`` when not given
    """

    def __init__(
        self,
        job_flow_id: str,
        step_id: str,
        aws_conn_id: str,
        poke_interval: float,
        target_states: Optional[Iterable[str]] = None,
        failed_states: Optional[Iterable[str]] = None,
    ):
        super().__init__()
        self.job_flow_id = job_flow_id
        self.step_id = step_id
        self.aws_conn_id = aws_conn_id
        self.poke_interval = poke_interval
        # ``or`` means an empty iterable also falls back to the defaults.
        self.target_states = target_states or ["COMPLETED"]
        self.failed_states = failed_states or ["CANCELLED", "FAILED", "INTERRUPTED"]

    def serialize(self) -> Tuple[str, Dict[str, Any]]:
        """Serializes EmrStepSensorTrigger arguments and classpath."""
        return (
            "astronomer.providers.amazon.aws.triggers.emr.EmrStepSensorTrigger",
            {
                "job_flow_id": self.job_flow_id,
                "step_id": self.step_id,
                "aws_conn_id": self.aws_conn_id,
                "poke_interval": self.poke_interval,
                "target_states": self.target_states,
                "failed_states": self.failed_states,
            },
        )

    async def run(self) -> AsyncIterator["TriggerEvent"]:
        """
        Run until AWS EMR cluster step reach target or failed state

        Yields exactly one ``TriggerEvent`` and then stops: ``status="success"`` when the step state is
        in ``target_states``, ``status="error"`` when it is in ``failed_states`` or when any exception
        is raised while polling. Other states are logged and polled again after ``poke_interval`` seconds.
        """
        hook = EmrStepSensorHookAsync(
            aws_conn_id=self.aws_conn_id, job_flow_id=self.job_flow_id, step_id=self.step_id
        )
        try:
            while True:
                response = await hook.emr_describe_step()
                state = hook.state_from_response(response)
                if state in self.target_states:
                    yield TriggerEvent({"status": "success", "message": f"Job flow currently {state}"})
                    # Terminal state: stop instead of sleeping and polling again.
                    return
                if state in self.failed_states:
                    yield TriggerEvent(
                        {"status": "error", "message": hook.failure_message_from_response(response)}
                    )
                    return
                self.log.info("EMR step state is %s. Sleeping for %s seconds.", state, self.poke_interval)
                await asyncio.sleep(self.poke_interval)
        except Exception as e:
            # Any polling failure is reported to the sensor as an error event rather than propagated.
            yield TriggerEvent({"status": "error", "message": str(e)})


class EmrJobFlowSensorTrigger(BaseTrigger):
"""
EmrJobFlowSensorTrigger is fired as deferred class with params to run the task in trigger worker, when
Expand Down
Loading

0 comments on commit f0061d3

Please sign in to comment.