diff --git a/astronomer/providers/amazon/aws/operators/emr.py b/astronomer/providers/amazon/aws/operators/emr.py
index 24295e947..955011733 100644
--- a/astronomer/providers/amazon/aws/operators/emr.py
+++ b/astronomer/providers/amazon/aws/operators/emr.py
@@ -1,22 +1,95 @@
-import warnings
+from __future__ import annotations
 
+from typing import Any
+
+from airflow.exceptions import AirflowException
+from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook
 from airflow.providers.amazon.aws.operators.emr import EmrContainerOperator
 
+from astronomer.providers.amazon.aws.triggers.emr import EmrContainerOperatorTrigger
+from astronomer.providers.utils.typing_compat import Context
+
 
 class EmrContainerOperatorAsync(EmrContainerOperator):
     """
-    This class is deprecated.
-    Please use :class: `~airflow.providers.amazon.aws.operators.emr.EmrContainerOperator`.
+    An async operator that submits jobs to EMR on EKS virtual clusters.
+
+    :param name: The name of the job run.
+    :param virtual_cluster_id: The EMR on EKS virtual cluster ID.
+    :param execution_role_arn: The IAM role ARN associated with the job run.
+    :param release_label: The Amazon EMR release version to use for the job run.
+    :param job_driver: Job configuration details, e.g. the Spark job parameters.
+    :param configuration_overrides: The configuration overrides for the job run,
+        specifically either application configuration or monitoring configuration.
+    :param client_request_token: The client idempotency token of the job run request.
+        Use this if you want to specify a unique ID to prevent the same job from being started twice.
+        If no token is provided, a UUIDv4 token will be generated for you.
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+    :param poll_interval: Time (in seconds) to wait between two consecutive calls to check query status on EMR.
+    :param max_tries: Deprecated - use max_polling_attempts instead.
+    :param max_polling_attempts: Maximum number of times to wait for the job run to finish.
+        Defaults to None, which will poll until the job is *not* in a pending, submitted, or running state.
+    :param tags: The tags assigned to job runs. Defaults to None.
     """
 
-    def __init__(self, *args, **kwargs):  # type: ignore[no-untyped-def]
-        warnings.warn(
-            (
-                "This module is deprecated. "
-                "Please use `airflow.providers.amazon.aws.operators.emr.EmrContainerOperator` "
-                "and set deferrable to True instead."
+    def execute(self, context: Context) -> str | None:
+        """Submit the job, then defer and hand control to the trigger."""
+        hook = EmrContainerHook(aws_conn_id=self.aws_conn_id, virtual_cluster_id=self.virtual_cluster_id)
+        job_id = hook.submit_job(
+            name=self.name,
+            execution_role_arn=self.execution_role_arn,
+            release_label=self.release_label,
+            job_driver=self.job_driver,
+            configuration_overrides=self.configuration_overrides,
+            client_request_token=self.client_request_token,
+            tags=self.tags,
+        )
+        try:
+            # for apache-airflow-providers-amazon<6.0.0
+            polling_attempts = self.max_tries  # type: ignore[attr-defined]
+        except AttributeError:  # pragma: no cover
+            # for apache-airflow-providers-amazon>=6.0.0,
+            # where max_tries is deprecated in favour of max_polling_attempts
+            polling_attempts = self.max_polling_attempts
+
+        query_state = hook.check_query_status(job_id)
+        if query_state in hook.SUCCESS_STATES:
+            self.log.info(
+                f"Query execution completed. Final state is {query_state}. "
+                f"EMR Containers Operator success {query_state}"
+            )
+            return job_id
+
+        if query_state in hook.FAILURE_STATES:
+            error_message = hook.get_job_failure_reason(job_id)
+            raise AirflowException(
+                f"EMR Containers job failed. Final state is {query_state}. "
+                f"query_execution_id is {job_id}. Error: {error_message}"
+            )
+
+        self.defer(
+            timeout=self.execution_timeout,
+            trigger=EmrContainerOperatorTrigger(
+                virtual_cluster_id=self.virtual_cluster_id,
+                job_id=job_id,
+                aws_conn_id=self.aws_conn_id,
+                poll_interval=self.poll_interval,
+                max_tries=polling_attempts,
             ),
-            DeprecationWarning,
-            stacklevel=2,
+            method_name="execute_complete",
         )
-        return super().__init__(*args, deferrable=True, **kwargs)
+
+        # unreachable; present only to satisfy mypy's missing-return check
+        return None  # pragma: no cover
+
+    def execute_complete(self, context: Context, event: dict[str, Any]) -> str:  # type: ignore[override]
+        """
+        Callback for when the trigger fires; returns immediately.
+        Relies on the trigger to throw an exception, otherwise it assumes the
+        execution was successful.
+        """
+        if "status" in event and event["status"] == "error":
+            raise AirflowException(event["message"])
+        self.log.info(event["message"])
+        job_id: str = event["job_id"]
+        return job_id
diff --git a/astronomer/providers/amazon/aws/triggers/emr.py b/astronomer/providers/amazon/aws/triggers/emr.py
index 9f9d06a66..7ca1f9d32 100644
--- a/astronomer/providers/amazon/aws/triggers/emr.py
+++ b/astronomer/providers/amazon/aws/triggers/emr.py
@@ -1,9 +1,10 @@
 import asyncio
-from typing import Any, AsyncIterator, Dict, Iterable, Optional, Tuple
+from typing import Any, AsyncIterator, Dict, Iterable, List, Optional, Tuple
 
 from airflow.triggers.base import BaseTrigger, TriggerEvent
 
 from astronomer.providers.amazon.aws.hooks.emr import (
+    EmrContainerHookAsync,
     EmrJobFlowHookAsync,
 )
 
@@ -36,6 +37,83 @@ def __init__(
         super().__init__(**kwargs)
 
 
+class EmrContainerOperatorTrigger(EmrContainerBaseTrigger):
+    """Poll the status of an EMR Containers job until it reaches a terminal state."""
+
+    INTERMEDIATE_STATES: List[str] = ["PENDING", "SUBMITTED", "RUNNING"]
+    FAILURE_STATES: List[str] = ["FAILED", "CANCELLED", "CANCEL_PENDING"]
+    SUCCESS_STATES: List[str] = ["COMPLETED"]
+    TERMINAL_STATES: List[str] = ["COMPLETED", "FAILED", "CANCELLED", "CANCEL_PENDING"]
+
+    def serialize(self) -> Tuple[str, Dict[str, Any]]:
+        """Serializes EmrContainerOperatorTrigger arguments and classpath."""
+        return (
+            "astronomer.providers.amazon.aws.triggers.emr.EmrContainerOperatorTrigger",
+            {
+                "virtual_cluster_id": self.virtual_cluster_id,
+                "job_id": self.job_id,
+                "aws_conn_id": self.aws_conn_id,
+                "max_tries": self.max_tries,
+                "poll_interval": self.poll_interval,
+            },
+        )
+
+    async def run(self) -> AsyncIterator["TriggerEvent"]:
+        """Run until the EMR Containers job reaches the desired (terminal) state."""
+        hook = EmrContainerHookAsync(aws_conn_id=self.aws_conn_id, virtual_cluster_id=self.virtual_cluster_id)
+        try:
+            try_number: int = 1
+            while True:
+                query_state = await hook.check_job_status(self.job_id)
+                if query_state is None:
+                    self.log.info("Try %s: Invalid query state. Retrying", try_number)
+                    await asyncio.sleep(self.poll_interval)
+                elif query_state in self.FAILURE_STATES:
+                    self.log.info(
+                        "Try %s: Query execution completed. Final state is %s", try_number, query_state
+                    )
+                    error_message = await hook.get_job_failure_reason(self.job_id)
+                    message = (
+                        f"EMR Containers job failed. Final state is {query_state}. "
+                        f"query_execution_id is {self.job_id}. Error: {error_message}"
+                    )
+                    yield TriggerEvent(
+                        {
+                            "status": "error",
+                            "message": message,
+                            "job_id": self.job_id,
+                        }
+                    )
+                elif query_state in self.SUCCESS_STATES:
+                    self.log.info(
+                        "Try %s: Query execution completed. Final state is %s", try_number, query_state
+                    )
+                    yield TriggerEvent(
+                        {
+                            "status": "success",
+                            "message": f"EMR Containers Operator success {query_state}",
+                            "job_id": self.job_id,
+                        }
+                    )
+                else:
+                    self.log.info(
+                        "Try %s: Query is still in non-terminal state - %s", try_number, query_state
+                    )
+                    await asyncio.sleep(self.poll_interval)
+                    if self.max_tries and try_number >= self.max_tries:
+                        yield TriggerEvent(
+                            {
+                                "status": "error",
+                                "message": "Timeout: Maximum retry limit exceeded",
+                                "job_id": self.job_id,
+                            }
+                        )
+
+                try_number += 1
+        except Exception as e:
+            yield TriggerEvent({"status": "error", "message": str(e)})
+
+
 class EmrJobFlowSensorTrigger(BaseTrigger):
     """
     EmrJobFlowSensorTrigger is fired as deferred class with params to run the task in trigger worker, when
diff --git a/tests/amazon/aws/operators/test_emr.py b/tests/amazon/aws/operators/test_emr.py
index c7bc77fcc..28b38bf25 100644
--- a/tests/amazon/aws/operators/test_emr.py
+++ b/tests/amazon/aws/operators/test_emr.py
@@ -1,8 +1,12 @@
 import os
+from unittest import mock
 
-from airflow.providers.amazon.aws.operators.emr import EmrContainerOperator
+import pytest
+from airflow.exceptions import AirflowException, TaskDeferred
+from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook
 
 from astronomer.providers.amazon.aws.operators.emr import EmrContainerOperatorAsync
+from astronomer.providers.amazon.aws.triggers.emr import EmrContainerOperatorTrigger
 
 VIRTUAL_CLUSTER_ID = os.getenv("VIRTUAL_CLUSTER_ID", "test-cluster")
 JOB_ROLE_ARN = os.getenv("JOB_ROLE_ARN", "arn:aws:iam::012345678912:role/emr_eks_default_role")
@@ -44,15 +48,64 @@ class TestEmrContainerOperatorAsync:
-    def test_init(self):
-        task = EmrContainerOperatorAsync(
-            task_id="start_job",
-            virtual_cluster_id=VIRTUAL_CLUSTER_ID,
-            execution_role_arn=JOB_ROLE_ARN,
-            release_label="emr-6.3.0-latest",
-            job_driver=JOB_DRIVER_ARG,
-            configuration_overrides=CONFIGURATION_OVERRIDES_ARG,
-            name="pi.py",
+    @pytest.mark.parametrize("status", EmrContainerHook.SUCCESS_STATES)
+    @mock.patch("astronomer.providers.amazon.aws.operators.emr.EmrContainerOperatorAsync.defer")
+    @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrContainerHook.check_query_status")
+    @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrContainerHook.submit_job")
+    def test_emr_container_operator_async_succeeded_before_defer(
+        self, submit_job, check_query_status, defer, status, context
+    ):
+        submit_job.return_value = JOB_ID
+        check_query_status.return_value = status
+        assert EMR_OPERATOR.execute(context) == JOB_ID
+
+        assert not defer.called
+
+    @pytest.mark.parametrize("status", EmrContainerHook.FAILURE_STATES)
+    @mock.patch("astronomer.providers.amazon.aws.operators.emr.EmrContainerOperatorAsync.defer")
+    @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrContainerHook.get_job_failure_reason")
+    @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrContainerHook.check_query_status")
+    @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrContainerHook.submit_job")
+    def test_emr_container_operator_async_terminal_before_defer(
+        self, submit_job, check_query_status, get_job_failure_reason, defer, status, context
+    ):
+        submit_job.return_value = JOB_ID
+        check_query_status.return_value = status
+
+        with pytest.raises(AirflowException):
+            EMR_OPERATOR.execute(context)
+
+        assert not defer.called
+
+    @pytest.mark.parametrize("status", EmrContainerHook.INTERMEDIATE_STATES)
+    @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrContainerHook.check_query_status")
+    @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrContainerHook.submit_job")
+    def test_emr_container_operator_async(self, submit_job, check_query_status, status, context):
+        submit_job.return_value = JOB_ID
+        check_query_status.return_value = status
+        with pytest.raises(TaskDeferred) as exc:
+            EMR_OPERATOR.execute(context)
+
+        assert isinstance(
+            exc.value.trigger, EmrContainerOperatorTrigger
+        ), "Trigger is not an EmrContainerOperatorTrigger"
+
+    @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrContainerHook.submit_job")
+    def test_execute_complete_success_task(self, submit_job):
+        """Assert that execute_complete succeeds"""
+        submit_job.return_value = JOB_ID
+        assert (
+            EMR_OPERATOR.execute_complete(
+                context=None, event={"status": "success", "message": "Job completed", "job_id": JOB_ID}
+            )
+            == JOB_ID
         )
-        assert isinstance(task, EmrContainerOperator)
-        assert task.deferrable is True
+
+    @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrContainerHook.submit_job")
+    def test_execute_complete_fail_task(self, submit_job):
+        """Assert that execute_complete raises AirflowException on an error event"""
+        submit_job.return_value = JOB_ID
+        with pytest.raises(AirflowException):
+            EMR_OPERATOR.execute_complete(
+                context=None, event={"status": "error", "message": "test failure message"}
+            )
diff --git a/tests/amazon/aws/triggers/test_emr.py b/tests/amazon/aws/triggers/test_emr.py
index cd976192d..39a8fff02 100644
--- a/tests/amazon/aws/triggers/test_emr.py
+++ b/tests/amazon/aws/triggers/test_emr.py
@@ -4,7 +4,10 @@
 import pytest
 from airflow.triggers.base import TriggerEvent
 
-from astronomer.providers.amazon.aws.triggers.emr import EmrJobFlowSensorTrigger
+from astronomer.providers.amazon.aws.triggers.emr import (
+    EmrContainerOperatorTrigger,
+    EmrJobFlowSensorTrigger,
+)
 
 VIRTUAL_CLUSTER_ID = "test_cluster_1"
 JOB_ID = "jobid-12122"
@@ -149,3 +152,112 @@ async def test_emr_job_flow_sensors_trigger_exception(self, mock_cluster_detail)
         task = [i async for i in self.TRIGGER.run()]
         assert len(task) == 1
         assert TriggerEvent({"status": "error", "message": "Test exception"}) in task
+
+
+class TestEmrContainerOperatorTrigger:
+    TRIGGER = EmrContainerOperatorTrigger(
+        virtual_cluster_id=VIRTUAL_CLUSTER_ID,
+        name=NAME,
+        job_id=JOB_ID,
+        aws_conn_id=AWS_CONN_ID,
+        max_tries=MAX_RETRIES,
+        poll_interval=POLL_INTERVAL,
+    )
+
+    def test_emr_container_operator_trigger_serialization(self):
+        """Asserts that EmrContainerOperatorTrigger correctly serializes its arguments and classpath."""
+
+        classpath, kwargs = self.TRIGGER.serialize()
+        assert classpath == "astronomer.providers.amazon.aws.triggers.emr.EmrContainerOperatorTrigger"
+        assert kwargs == {
+            "virtual_cluster_id": VIRTUAL_CLUSTER_ID,
+            "job_id": JOB_ID,
+            "aws_conn_id": AWS_CONN_ID,
+            "poll_interval": POLL_INTERVAL,
+            "max_tries": MAX_RETRIES,
+        }
+
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize(
+        "mock_status",
+        [
+            "PENDING",
+            "SUBMITTED",
+            "RUNNING",
+            None,
+        ],
+    )
+    @mock.patch("astronomer.providers.amazon.aws.hooks.emr.EmrContainerHookAsync.check_job_status")
+    async def test_emr_container_operator_trigger_run(self, mock_query_status, mock_status):
+        """Assert that the trigger keeps running (and sleeps) while the job state is intermediate"""
+        mock_query_status.return_value = mock_status
+
+        task = asyncio.create_task(self.TRIGGER.run().__anext__())
+        await asyncio.sleep(0.5)
+
+        # TriggerEvent was not returned
+        assert task.done() is False
+        asyncio.get_event_loop().stop()
+
+    @pytest.mark.asyncio
+    @mock.patch("astronomer.providers.amazon.aws.hooks.emr.EmrContainerHookAsync.check_job_status")
+    async def test_emr_container_operator_trigger_completed(self, mock_query_status):
+        """Assert that EmrContainerOperatorTrigger yields a success event when the job completes."""
+        mock_query_status.return_value = "COMPLETED"
+
+        generator = self.TRIGGER.run()
+        actual = await generator.asend(None)
+        msg = "EMR Containers Operator success COMPLETED"
+        assert TriggerEvent({"status": "success", "message": msg, "job_id": JOB_ID}) == actual
+
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize(
+        "mock_status",
+        ["FAILED", "CANCELLED", "CANCEL_PENDING"],
+    )
+    @mock.patch("astronomer.providers.amazon.aws.hooks.emr.EmrContainerHookAsync.get_job_failure_reason")
+    @mock.patch("astronomer.providers.amazon.aws.hooks.emr.EmrContainerHookAsync.check_job_status")
+    async def test_emr_container_operator_trigger_failure_status(
+        self, mock_query_status, mock_failure_reason, mock_status
+    ):
+        """Assert that EmrContainerOperatorTrigger yields an error event on failure states."""
+        mock_query_status.return_value = mock_status
+        mock_failure_reason.return_value = None
+
+        generator = self.TRIGGER.run()
+        actual = await generator.asend(None)
+        message = (
+            f"EMR Containers job failed. Final state is {mock_status}. "
+            f"query_execution_id is {JOB_ID}. Error: {None}"
+        )
+        assert TriggerEvent({"status": "error", "message": message, "job_id": JOB_ID}) == actual
+
+    @pytest.mark.asyncio
+    @mock.patch("astronomer.providers.amazon.aws.hooks.emr.EmrContainerHookAsync.check_job_status")
+    async def test_emr_container_operator_trigger_exception(self, mock_query_status):
+        """Assert that EmrContainerOperatorTrigger yields an error event if the hook raises an exception"""
+        mock_query_status.side_effect = Exception("Test exception")
+
+        task = [i async for i in self.TRIGGER.run()]
+        assert len(task) == 1
+        assert TriggerEvent({"status": "error", "message": "Test exception"}) in task
+
+    @pytest.mark.asyncio
+    @mock.patch("astronomer.providers.amazon.aws.hooks.emr.EmrContainerHookAsync.check_job_status")
+    async def test_emr_container_operator_trigger_timeout(self, mock_query_status):
+        """Assert that EmrContainerOperatorTrigger yields an error event once max_tries is exceeded"""
+        mock_query_status.return_value = "PENDING"
+        trigger = EmrContainerOperatorTrigger(
+            name=NAME,
+            virtual_cluster_id=VIRTUAL_CLUSTER_ID,
+            job_id=JOB_ID,
+            aws_conn_id=AWS_CONN_ID,
+            poll_interval=1,
+            max_tries=2,
+        )
+        generator = trigger.run()
+        actual = await generator.asend(None)
+        expected = TriggerEvent(
+            {"status": "error", "message": "Timeout: Maximum retry limit exceeded", "job_id": JOB_ID}
+        )
+        assert actual == expected
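
Usage sketch (for reviewers; not part of the patch): a minimal DAG wiring up the EmrContainerOperatorAsync added above. The DAG id, schedule, and Spark entry point below are illustrative placeholders; the cluster ID, role ARN, release label, and job name reuse the defaults from tests/amazon/aws/operators/test_emr.py.

```python
from datetime import datetime, timedelta

from airflow import DAG

from astronomer.providers.amazon.aws.operators.emr import EmrContainerOperatorAsync

# Hypothetical Spark job driver; any valid EMR on EKS job driver works here.
JOB_DRIVER_ARG = {
    "sparkSubmitJobDriver": {
        "entryPoint": "s3://my-bucket/pi.py",  # placeholder script location
        "sparkSubmitParameters": "--conf spark.executor.instances=2",
    }
}

with DAG(
    dag_id="example_emr_eks_job",  # hypothetical DAG id
    start_date=datetime(2022, 1, 1),
    schedule_interval=None,
    catchup=False,
):
    run_job = EmrContainerOperatorAsync(
        task_id="start_job",
        virtual_cluster_id="test-cluster",
        execution_role_arn="arn:aws:iam::012345678912:role/emr_eks_default_role",
        release_label="emr-6.3.0-latest",
        job_driver=JOB_DRIVER_ARG,
        configuration_overrides={},
        name="pi.py",
        poll_interval=5,
        # execute() passes execution_timeout to self.defer(), so this also bounds
        # how long the trigger may keep polling before the deferred task times out.
        execution_timeout=timedelta(hours=1),
    )
```

Note that execute() checks the job state once right after submission and only defers when the job is still in an intermediate (pending/submitted/running) state, so jobs that finish immediately never occupy a triggerer slot.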