From f0061d3f32f66fb8a547eeb89e5dac45e2cec8b1 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Thu, 21 Dec 2023 12:06:21 +0530 Subject: [PATCH] Depreacate EmrContainerSensorAsync and EmrStepSensorAsync (#1390) * feat(amazon): deprecate EmrStepSensorAsync and EmrContainerSensorAsync * feat(amazon): remove EmrContainerSensorTrigger and EmrStepSensorTrigger --- .../aws/example_dags/example_emr_sensor.py | 2 +- .../providers/amazon/aws/sensors/emr.py | 112 +++------- .../providers/amazon/aws/triggers/emr.py | 114 ---------- tests/amazon/aws/sensors/test_emr_sensors.py | 108 ++------- tests/amazon/aws/triggers/test_emr.py | 210 ------------------ 5 files changed, 51 insertions(+), 495 deletions(-) diff --git a/astronomer/providers/amazon/aws/example_dags/example_emr_sensor.py b/astronomer/providers/amazon/aws/example_dags/example_emr_sensor.py index b5cee45d5..6baf3fb1a 100644 --- a/astronomer/providers/amazon/aws/example_dags/example_emr_sensor.py +++ b/astronomer/providers/amazon/aws/example_dags/example_emr_sensor.py @@ -109,7 +109,7 @@ def check_dag_status(**kwargs: Any) -> None: # [START howto_sensor_emr_step_async] watch_step = EmrStepSensorAsync( task_id="watch_step", - job_flow_id=create_job_flow.output, # type: ignore[arg-type] + job_flow_id=create_job_flow.output, step_id="{{ task_instance.xcom_pull(task_ids='add_steps', key='return_value')[0] }}", aws_conn_id=AWS_CONN_ID, ) diff --git a/astronomer/providers/amazon/aws/sensors/emr.py b/astronomer/providers/amazon/aws/sensors/emr.py index ff2e21934..bf1a4e059 100644 --- a/astronomer/providers/amazon/aws/sensors/emr.py +++ b/astronomer/providers/amazon/aws/sensors/emr.py @@ -1,5 +1,6 @@ from __future__ import annotations +import warnings from datetime import timedelta from typing import Any @@ -12,105 +13,48 @@ ) from astronomer.providers.amazon.aws.triggers.emr import ( - EmrContainerSensorTrigger, EmrJobFlowSensorTrigger, - EmrStepSensorTrigger, ) -from astronomer.providers.utils.sensor_util import poke, raise_error_or_skip_exception +from astronomer.providers.utils.sensor_util import raise_error_or_skip_exception from astronomer.providers.utils.typing_compat import Context class EmrContainerSensorAsync(EmrContainerSensor): """ - EmrContainerSensorAsync is async version of EmrContainerSensor, - Asks for the state of the job run until it reaches a failure state or success state. - If the job run fails, the task will fail. - - :param virtual_cluster_id: Reference Emr cluster id - :param job_id: job_id to check the state - :param max_retries: Number of times to poll for query state before - returning the current state, defaults to None - :param aws_conn_id: aws connection to use, defaults to ``aws_default`` - :param poll_interval: Time in seconds to wait between two consecutive call to - check query status on athena, defaults to 10 + This class is deprecated. + Please use :class: `~airflow.providers.amazon.aws.sensors.emr.EmrContainerSensor`. """ - def execute(self, context: Context) -> None: - """Defers trigger class to poll for state of the job run until it reaches a failure state or success state""" - if not poke(self, context): - self.defer( - timeout=timedelta(seconds=self.timeout), - trigger=EmrContainerSensorTrigger( - virtual_cluster_id=self.virtual_cluster_id, - job_id=self.job_id, - max_tries=self.max_retries, - aws_conn_id=self.aws_conn_id, - poll_interval=self.poll_interval, - ), - method_name="execute_complete", - ) - - # Ignoring the override type check because the parent class specifies "context: Any" but specifying it as - # "context: Context" is accurate as it's more specific. - def execute_complete(self, context: Context, event: dict[str, str]) -> None: # type: ignore[override] - """ - Callback for when the trigger fires - returns immediately. - Relies on trigger to throw an exception, otherwise it assumes execution was - successful. - """ - if event: - if event["status"] == "error": - raise_error_or_skip_exception(self.soft_fail, event["message"]) - self.log.info(event["message"]) - return None + def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] + warnings.warn( + ( + "This module is deprecated. " + "Please use `airflow.providers.amazon.aws.sensors.emr.EmrContainerSensor` " + "and set deferrable to True instead." + ), + DeprecationWarning, + stacklevel=2, + ) + return super().__init__(*args, deferrable=True, **kwargs) class EmrStepSensorAsync(EmrStepSensor): """ - Async (deferring) version of EmrStepSensor - - Asks for the state of the step until it reaches any of the target states. - If the sensor errors out, then the task will fail - With the default target states, sensor waits step to be COMPLETED. - - For more details see - - https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/emr.html#EMR.Client.describe_step - - :param job_flow_id: job_flow_id which contains the step check the state of - :param step_id: step to check the state of - :param target_states: the target states, sensor waits until - step reaches any of these states - :param failed_states: the failure states, sensor fails when - step reaches any of these states + This class is deprecated. + Please use :class: `~airflow.providers.amazon.aws.sensors.emr.EmrStepSensor`. """ - def execute(self, context: Context) -> None: - """Deferred and give control to trigger""" - if not poke(self, context): - self.defer( - timeout=timedelta(seconds=self.timeout), - trigger=EmrStepSensorTrigger( - job_flow_id=self.job_flow_id, - step_id=self.step_id, - target_states=self.target_states, - failed_states=self.failed_states, - aws_conn_id=self.aws_conn_id, - poke_interval=self.poke_interval, - ), - method_name="execute_complete", - ) - - def execute_complete(self, context: Context, event: dict[str, Any]) -> None: # type: ignore[override] - """ - Callback for when the trigger fires - returns immediately. - Relies on trigger to throw an exception, otherwise it assumes execution was - successful. - """ - if event: - if event["status"] == "error": - raise_error_or_skip_exception(self.soft_fail, event["message"]) - self.log.info(event.get("message")) - self.log.info("%s completed successfully.", self.job_flow_id) + def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] + warnings.warn( + ( + "This module is deprecated. " + "Please use `airflow.providers.amazon.aws.sensors.emr.EmrStepSensor` " + "and set deferrable to True instead." + ), + DeprecationWarning, + stacklevel=2, + ) + return super().__init__(*args, deferrable=True, **kwargs) class EmrJobFlowSensorAsync(EmrJobFlowSensor): diff --git a/astronomer/providers/amazon/aws/triggers/emr.py b/astronomer/providers/amazon/aws/triggers/emr.py index 476c18a6b..7ca1f9d32 100644 --- a/astronomer/providers/amazon/aws/triggers/emr.py +++ b/astronomer/providers/amazon/aws/triggers/emr.py @@ -6,7 +6,6 @@ from astronomer.providers.amazon.aws.hooks.emr import ( EmrContainerHookAsync, EmrJobFlowHookAsync, - EmrStepSensorHookAsync, ) @@ -38,52 +37,6 @@ def __init__( super().__init__(**kwargs) -class EmrContainerSensorTrigger(EmrContainerBaseTrigger): - """Poll for the status of EMR container until reaches terminal state""" - - def serialize(self) -> Tuple[str, Dict[str, Any]]: - """Serializes EmrContainerSensorTrigger arguments and classpath.""" - return ( - "astronomer.providers.amazon.aws.triggers.emr.EmrContainerSensorTrigger", - { - "virtual_cluster_id": self.virtual_cluster_id, - "job_id": self.job_id, - "aws_conn_id": self.aws_conn_id, - "max_tries": self.max_tries, - "poll_interval": self.poll_interval, - }, - ) - - async def run(self) -> AsyncIterator["TriggerEvent"]: - """Make async connection to EMR container, polls for the job state""" - hook = EmrContainerHookAsync(aws_conn_id=self.aws_conn_id, virtual_cluster_id=self.virtual_cluster_id) - try: - try_number: int = 1 - while True: - query_status = await hook.check_job_status(job_id=self.job_id) - if query_status is None or query_status in ("PENDING", "SUBMITTED", "RUNNING"): - await asyncio.sleep(self.poll_interval) - elif query_status in ("FAILED", "CANCELLED", "CANCEL_PENDING"): - msg = f"EMR Containers sensor failed {query_status}" - yield TriggerEvent({"status": "error", "message": msg}) - else: - msg = "EMR Containers sensors completed" - yield TriggerEvent({"status": "success", "message": msg}) - - if self.max_tries and try_number >= self.max_tries: - yield TriggerEvent( - { - "status": "error", - "message": "Timeout: Maximum retry limit exceed", - "job_id": self.job_id, - } - ) - - try_number += 1 - except Exception as e: - yield TriggerEvent({"status": "error", "message": str(e)}) - - class EmrContainerOperatorTrigger(EmrContainerBaseTrigger): """Poll for the status of EMR container until reaches terminal state""" @@ -161,73 +114,6 @@ async def run(self) -> AsyncIterator["TriggerEvent"]: yield TriggerEvent({"status": "error", "message": str(e)}) -class EmrStepSensorTrigger(BaseTrigger): - """ - A trigger that fires once AWS EMR cluster step reaches either target or failed state - - :param job_flow_id: job_flow_id which contains the step check the state of - :param step_id: step to check the state of - :param aws_conn_id: aws connection to use, defaults to 'aws_default' - :param poke_interval: Time in seconds to wait between two consecutive call to - check emr cluster step state - :param target_states: the target states, sensor waits until - step reaches any of these states - :param failed_states: the failure states, sensor fails when - step reaches any of these states - """ - - def __init__( - self, - job_flow_id: str, - step_id: str, - aws_conn_id: str, - poke_interval: float, - target_states: Optional[Iterable[str]] = None, - failed_states: Optional[Iterable[str]] = None, - ): - super().__init__() - self.job_flow_id = job_flow_id - self.step_id = step_id - self.aws_conn_id = aws_conn_id - self.poke_interval = poke_interval - self.target_states = target_states or ["COMPLETED"] - self.failed_states = failed_states or ["CANCELLED", "FAILED", "INTERRUPTED"] - - def serialize(self) -> Tuple[str, Dict[str, Any]]: - """Serializes EmrStepSensorTrigger arguments and classpath.""" - return ( - "astronomer.providers.amazon.aws.triggers.emr.EmrStepSensorTrigger", - { - "job_flow_id": self.job_flow_id, - "step_id": self.step_id, - "aws_conn_id": self.aws_conn_id, - "poke_interval": self.poke_interval, - "target_states": self.target_states, - "failed_states": self.failed_states, - }, - ) - - async def run(self) -> AsyncIterator["TriggerEvent"]: - """Run until AWS EMR cluster step reach target or failed state""" - hook = EmrStepSensorHookAsync( - aws_conn_id=self.aws_conn_id, job_flow_id=self.job_flow_id, step_id=self.step_id - ) - try: - while True: - response = await hook.emr_describe_step() - state = hook.state_from_response(response) - if state in self.target_states: - yield TriggerEvent({"status": "success", "message": f"Job flow currently {state}"}) - elif state in self.failed_states: - yield TriggerEvent( - {"status": "error", "message": hook.failure_message_from_response(response)} - ) - self.log.info("EMR step state is %s. Sleeping for %s seconds.", state, self.poke_interval) - await asyncio.sleep(self.poke_interval) - except Exception as e: - yield TriggerEvent({"status": "error", "message": str(e)}) - - class EmrJobFlowSensorTrigger(BaseTrigger): """ EmrJobFlowSensorTrigger is fired as deferred class with params to run the task in trigger worker, when diff --git a/tests/amazon/aws/sensors/test_emr_sensors.py b/tests/amazon/aws/sensors/test_emr_sensors.py index efff4f815..182ad93fb 100644 --- a/tests/amazon/aws/sensors/test_emr_sensors.py +++ b/tests/amazon/aws/sensors/test_emr_sensors.py @@ -3,6 +3,9 @@ import pytest from airflow.exceptions import AirflowException, TaskDeferred +from airflow.providers.amazon.aws.sensors.emr import ( + EmrStepSensor, +) from astronomer.providers.amazon.aws.sensors.emr import ( EmrContainerSensorAsync, @@ -10,9 +13,7 @@ EmrStepSensorAsync, ) from astronomer.providers.amazon.aws.triggers.emr import ( - EmrContainerSensorTrigger, EmrJobFlowSensorTrigger, - EmrStepSensorTrigger, ) TASK_ID = "test_emr_container_sensor" @@ -28,55 +29,28 @@ class TestEmrContainerSensorAsync: - TASK = EmrContainerSensorAsync( - task_id=TASK_ID, - virtual_cluster_id=VIRTUAL_CLUSTER_ID, - job_id=JOB_ID, - poll_interval=5, - max_retries=1, - aws_conn_id=AWS_CONN_ID, - ) - - @mock.patch(f"{MODULE}.EmrContainerSensorAsync.defer") - @mock.patch(f"{MODULE}.EmrContainerSensorAsync.poke", return_value=True) - def test_emr_container_sensor_async_finish_before_deferred(self, mock_poke, mock_defer, context): - """Assert task is not deferred when it receives a finish status before deferring""" - self.TASK.execute(context) - assert not mock_defer.called - - @mock.patch(f"{MODULE}.EmrContainerSensorAsync.poke", return_value=False) - def test_emr_container_sensor_async(self, mock_poke, context): - """ - Asserts that a task is deferred and a EmrContainerSensorTrigger will be fired - when the EmrContainerSensorAsync is executed. - """ - - with pytest.raises(TaskDeferred) as exc: - self.TASK.execute(context) - assert isinstance( - exc.value.trigger, EmrContainerSensorTrigger - ), "Trigger is not a EmrContainerSensorTrigger" - - def test_emr_container_sensor_async_execute_failure(self, context): - """Tests that an AirflowException is raised in case of error event""" - - with pytest.raises(AirflowException): - self.TASK.execute_complete( - context=None, event={"status": "error", "message": "test failure message"} - ) - - def test_emr_container_sensor_async_execute_complete(self): - """Asserts that logging occurs as expected""" - - assert ( - self.TASK.execute_complete(context=None, event={"status": "success", "message": "Job completed"}) - is None + def test_init(self): + task = EmrContainerSensorAsync( + task_id=TASK_ID, + virtual_cluster_id=VIRTUAL_CLUSTER_ID, + job_id=JOB_ID, + poll_interval=5, + max_retries=1, + aws_conn_id=AWS_CONN_ID, ) + assert isinstance(task, EmrContainerSensorAsync) + assert task.deferrable is True - def test_emr_container_sensor_async_execute_complete_event_none(self): - """Asserts that logging occurs as expected""" - assert self.TASK.execute_complete(context=None, event=None) is None +class TestEmrStepSensorAsync: + def test_init(self): + task = EmrStepSensorAsync( + task_id="emr_step_sensor", + job_flow_id=JOB_ID, + step_id=STEP_ID, + ) + assert isinstance(task, EmrStepSensor) + assert task.deferrable is True class TestEmrJobFlowSensorAsync: @@ -140,41 +114,3 @@ def test_emr_job_flow_sensor_async_execute_complete_event_none(self): """Asserts that logging occurs as expected""" assert self.TASK.execute_complete(context=None, event=None) is None - - -class TestEmrStepSensorAsync: - TASK = EmrStepSensorAsync( - task_id="emr_step_sensor", - job_flow_id=JOB_ID, - step_id=STEP_ID, - ) - - @mock.patch(f"{MODULE}.EmrStepSensorAsync.defer") - @mock.patch(f"{MODULE}.EmrStepSensorAsync.poke", return_value=True) - def test_emr_step_sensor_async_finish_before_deferred(self, mock_poke, mock_defer, context): - """Assert task is not deferred when it receives a finish status before deferring""" - self.TASK.execute(context) - assert not mock_defer.called - - @mock.patch(f"{MODULE}.EmrStepSensorAsync.poke", return_value=False) - def test_emr_step_sensor_async(self, mock_poke, context): - """Assert execute method defer for EmrStepSensorAsync sensor""" - - with pytest.raises(TaskDeferred) as exc: - self.TASK.execute(context) - assert isinstance(exc.value.trigger, EmrStepSensorTrigger), "Trigger is not a EmrStepSensorTrigger" - - def test_emr_step_sensor_execute_complete_success(self): - """Assert execute_complete log success message when triggerer fire with target state""" - - with mock.patch.object(self.TASK.log, "info") as mock_log_info: - self.TASK.execute_complete( - context={}, event={"status": "success", "message": "Job flow currently COMPLETED"} - ) - mock_log_info.assert_called_with("%s completed successfully.", "j-T0CT8Z0C20NT") - - def test_emr_step_sensor_execute_complete_failure(self): - """Assert execute_complete method fail""" - - with pytest.raises(AirflowException): - self.TASK.execute_complete(context={}, event={"status": "error", "message": ""}) diff --git a/tests/amazon/aws/triggers/test_emr.py b/tests/amazon/aws/triggers/test_emr.py index 178693acc..39a8fff02 100644 --- a/tests/amazon/aws/triggers/test_emr.py +++ b/tests/amazon/aws/triggers/test_emr.py @@ -6,9 +6,7 @@ from astronomer.providers.amazon.aws.triggers.emr import ( EmrContainerOperatorTrigger, - EmrContainerSensorTrigger, EmrJobFlowSensorTrigger, - EmrStepSensorTrigger, ) VIRTUAL_CLUSTER_ID = "test_cluster_1" @@ -66,214 +64,6 @@ def _emr_describe_step_response(state): } -class TestEmrContainerSensorTrigger: - TRIGGER = EmrContainerSensorTrigger( - virtual_cluster_id=VIRTUAL_CLUSTER_ID, - job_id=JOB_ID, - max_tries=MAX_RETRIES, - aws_conn_id=AWS_CONN_ID, - poll_interval=POLL_INTERVAL, - ) - - def test_emr_container_sensors_trigger_serialization(self): - """ - Asserts that the EmrContainerSensorTrigger correctly serializes its arguments - and classpath. - """ - - classpath, kwargs = self.TRIGGER.serialize() - assert classpath == "astronomer.providers.amazon.aws.triggers.emr.EmrContainerSensorTrigger" - assert kwargs == { - "virtual_cluster_id": VIRTUAL_CLUSTER_ID, - "job_id": JOB_ID, - "max_tries": MAX_RETRIES, - "poll_interval": POLL_INTERVAL, - "aws_conn_id": AWS_CONN_ID, - } - - @pytest.mark.asyncio - @pytest.mark.parametrize( - "mock_status", - [ - "PENDING", - "SUBMITTED", - "RUNNING", - None, - ], - ) - @mock.patch("astronomer.providers.amazon.aws.hooks.emr.EmrContainerHookAsync.check_job_status") - async def test_emr_container_sensors_trigger_run(self, mock_query_status, mock_status): - """Test if the task is run is in trigger successfully.""" - mock_query_status.return_value = mock_status - - task = asyncio.create_task(self.TRIGGER.run().__anext__()) - await asyncio.sleep(0.5) - - # TriggerEvent was not returned - assert task.done() is False - asyncio.get_event_loop().stop() - - @pytest.mark.asyncio - @pytest.mark.parametrize( - "mock_status", - ["COMPLETED"], - ) - @mock.patch("astronomer.providers.amazon.aws.hooks.emr.EmrContainerHookAsync.check_job_status") - async def test_emr_container_sensors_trigger_completed(self, mock_query_status, mock_status): - """ - Test if the task is run is in trigger failure status. - """ - mock_query_status.return_value = mock_status - - generator = self.TRIGGER.run() - response = await generator.asend(None) - msg = "EMR Containers sensors completed" - assert TriggerEvent({"status": "success", "message": msg}) == response - - @pytest.mark.asyncio - @pytest.mark.parametrize( - "mock_status", - ["FAILED", "CANCELLED", "CANCEL_PENDING"], - ) - @mock.patch("astronomer.providers.amazon.aws.hooks.emr.EmrContainerHookAsync.check_job_status") - async def test_emr_container_sensors_trigger_failure_status(self, mock_query_status, mock_status): - """ - Test if the task is run is in trigger failure status. - """ - mock_query_status.return_value = mock_status - - generator = self.TRIGGER.run() - response = await generator.asend(None) - msg = f"EMR Containers sensor failed {mock_status}" - assert TriggerEvent({"status": "error", "message": msg}) == response - - @pytest.mark.asyncio - @mock.patch("astronomer.providers.amazon.aws.hooks.emr.EmrContainerHookAsync.check_job_status") - async def test_emr_container_sensors_trigger_exception(self, mock_query_status): - """ - Test EMR container sensors with raise exception - """ - mock_query_status.side_effect = Exception("Test exception") - - task = [i async for i in self.TRIGGER.run()] - assert len(task) == 1 - assert TriggerEvent({"status": "error", "message": "Test exception"}) in task - - @pytest.mark.asyncio - @mock.patch("astronomer.providers.amazon.aws.hooks.emr.EmrContainerHookAsync.check_job_status") - async def test_emr_container_sensor_trigger_timeout(self, mock_query_status): - """Asserts that the EmrContainerSensorTrigger triggers correct event in case of timeout""" - mock_query_status.return_value = "PENDING" - trigger = EmrContainerSensorTrigger( - virtual_cluster_id=VIRTUAL_CLUSTER_ID, - job_id=JOB_ID, - aws_conn_id=AWS_CONN_ID, - poll_interval=1, - max_tries=2, - ) - generator = trigger.run() - actual = await generator.asend(None) - expected = TriggerEvent( - {"status": "error", "message": "Timeout: Maximum retry limit exceed", "job_id": JOB_ID} - ) - assert actual == expected - - -class TestEmrStepSensorTrigger: - def test_emr_step_sensor_serialization(self): - """Asserts that the EmrStepSensorTrigger correctly serializes its arguments and classpath.""" - trigger = EmrStepSensorTrigger( - job_flow_id=JOB_FLOW_ID, step_id=STEP_ID, aws_conn_id=AWS_CONN_ID, poke_interval=60 - ) - - classpath, kwargs = trigger.serialize() - assert classpath == "astronomer.providers.amazon.aws.triggers.emr.EmrStepSensorTrigger" - assert kwargs == { - "job_flow_id": JOB_FLOW_ID, - "step_id": STEP_ID, - "aws_conn_id": AWS_CONN_ID, - "poke_interval": 60, - "target_states": ["COMPLETED"], - "failed_states": ["CANCELLED", "FAILED", "INTERRUPTED"], - } - - @pytest.mark.asyncio - @mock.patch("astronomer.providers.amazon.aws.hooks.emr.EmrStepSensorHookAsync.emr_describe_step") - async def test_emr_step_sensor_trigger_run_success(self, emr_describe_step): - """Assert EmrStepSensorTrigger run method success""" - emr_describe_step.return_value = _emr_describe_step_response("COMPLETED") - trigger = EmrStepSensorTrigger( - job_flow_id=JOB_FLOW_ID, step_id=STEP_ID, aws_conn_id=AWS_CONN_ID, poke_interval=60 - ) - generator = trigger.run() - actual = await generator.asend(None) - expected = TriggerEvent({"status": "success", "message": "Job flow currently COMPLETED"}) - assert expected == actual - - @pytest.mark.asyncio - @pytest.mark.parametrize( - "mock_response", - [ - _emr_describe_step_response("PENDING"), - _emr_describe_step_response("CANCEL_PENDING"), - _emr_describe_step_response("RUNNING"), - _emr_describe_step_response(None), - ], - ) - @mock.patch("astronomer.providers.amazon.aws.hooks.emr.EmrStepSensorHookAsync.emr_describe_step") - async def test_emr_step_sensor_trigger_run_pending(self, emr_describe_step, mock_response): - """Assert run method of EmrStepSensorHookAsync sleep""" - emr_describe_step.return_value = mock_response - trigger = EmrStepSensorTrigger( - job_flow_id=JOB_FLOW_ID, step_id=STEP_ID, aws_conn_id=AWS_CONN_ID, poke_interval=5 - ) - task = asyncio.create_task(trigger.run().__anext__()) - await asyncio.sleep(0.5) - - # TriggerEvent was not returned - assert task.done() is False - asyncio.get_event_loop().stop() - - @pytest.mark.asyncio - @pytest.mark.parametrize( - "mock_response", - [ - _emr_describe_step_response("CANCELLED"), - _emr_describe_step_response("FAILED"), - _emr_describe_step_response("INTERRUPTED"), - ], - ) - @mock.patch("astronomer.providers.amazon.aws.hooks.emr.EmrStepSensorHookAsync.emr_describe_step") - async def test_emr_step_sensor_trigger_run_fail(self, emr_describe_step, mock_response): - """Assert run method of EmrStepSensorHookAsync fail""" - emr_describe_step.return_value = mock_response - trigger = EmrStepSensorTrigger( - job_flow_id=JOB_FLOW_ID, - step_id=STEP_ID, - aws_conn_id=AWS_CONN_ID, - poke_interval=5, - failed_states=["CANCELLED", "FAILED", "INTERRUPTED"], - ) - generator = trigger.run() - actual = await generator.asend(None) - expected = TriggerEvent( - {"status": "error", "message": "for reason Unknown Error with message and log file "} - ) - assert actual == expected - - @pytest.mark.asyncio - @mock.patch("astronomer.providers.amazon.aws.hooks.emr.EmrStepSensorHookAsync.emr_describe_step") - async def test_emr_step_sensor_trigger_run_failure(self, emr_describe_step): - """Test EmrStepSensorTrigger run method fail""" - emr_describe_step.side_effect = Exception("Test exception") - trigger = EmrStepSensorTrigger( - job_flow_id=JOB_FLOW_ID, step_id=STEP_ID, aws_conn_id=AWS_CONN_ID, poke_interval=60 - ) - generator = trigger.run() - actual = await generator.asend(None) - assert TriggerEvent({"status": "error", "message": "Test exception"}) == actual - - class TestEmrJobFlowSensorTrigger: TRIGGER = EmrJobFlowSensorTrigger( job_flow_id=JOB_ID,