diff --git a/astronomer/providers/amazon/aws/operators/emr.py b/astronomer/providers/amazon/aws/operators/emr.py index 955011733..24295e947 100644 --- a/astronomer/providers/amazon/aws/operators/emr.py +++ b/astronomer/providers/amazon/aws/operators/emr.py @@ -1,95 +1,22 @@ -from __future__ import annotations +import warnings -from typing import Any - -from airflow.exceptions import AirflowException -from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook from airflow.providers.amazon.aws.operators.emr import EmrContainerOperator -from astronomer.providers.amazon.aws.triggers.emr import EmrContainerOperatorTrigger -from astronomer.providers.utils.typing_compat import Context - class EmrContainerOperatorAsync(EmrContainerOperator): """ - An async operator that submits jobs to EMR on EKS virtual clusters. - - :param name: The name of the job run. - :param virtual_cluster_id: The EMR on EKS virtual cluster ID - :param execution_role_arn: The IAM role ARN associated with the job run. - :param release_label: The Amazon EMR release version to use for the job run. - :param job_driver: Job configuration details, e.g. the Spark job parameters. - :param configuration_overrides: The configuration overrides for the job run, - specifically either application configuration or monitoring configuration. - :param client_request_token: The client idempotency token of the job run request. - Use this if you want to specify a unique ID to prevent two jobs from getting started. - If no token is provided, a UUIDv4 token will be generated for you. - :param aws_conn_id: The Airflow connection used for AWS credentials. - :param poll_interval: Time (in seconds) to wait between two consecutive calls to check query status on EMR - :param max_tries: Deprecated - use max_polling_attempts instead. - :param max_polling_attempts: Maximum number of times to wait for the job run to finish. 
- Defaults to None, which will poll until the job is *not* in a pending, submitted, or running state. - :param tags: The tags assigned to job runs. Defaults to None + This class is deprecated. + Please use :class:`~airflow.providers.amazon.aws.operators.emr.EmrContainerOperator`. """ - def execute(self, context: Context) -> str | None: - """Deferred and give control to trigger""" - hook = EmrContainerHook(aws_conn_id=self.aws_conn_id, virtual_cluster_id=self.virtual_cluster_id) - job_id = hook.submit_job( - name=self.name, - execution_role_arn=self.execution_role_arn, - release_label=self.release_label, - job_driver=self.job_driver, - configuration_overrides=self.configuration_overrides, - client_request_token=self.client_request_token, - tags=self.tags, - ) - try: - # for apache-airflow-providers-amazon<6.0.0 - polling_attempts = self.max_tries # type: ignore[attr-defined] - except AttributeError: # pragma: no cover - # for apache-airflow-providers-amazon>=6.0.0 - # max_tries is deprecated so instead of max_tries using self.max_polling_attempts - polling_attempts = self.max_polling_attempts - - query_state = hook.check_query_status(job_id) - if query_state in hook.SUCCESS_STATES: - self.log.info( - f"Try : Query execution completed. Final state is {query_state}" - f"EMR Containers Operator success {query_state}" - ) - return job_id - - if query_state in hook.FAILURE_STATES: - error_message = self.hook.get_job_failure_reason(job_id) - raise AirflowException( - f"EMR Containers job failed. Final state is {query_state}. " - f"query_execution_id is {job_id}. Error: {error_message}" - ) - - self.defer( - timeout=self.execution_timeout, - trigger=EmrContainerOperatorTrigger( - virtual_cluster_id=self.virtual_cluster_id, - job_id=job_id, - aws_conn_id=self.aws_conn_id, - poll_interval=self.poll_interval, - max_tries=polling_attempts, + def __init__(self, *args, **kwargs): # type: ignore[no-untyped-def] + warnings.warn( + ( + "This module is deprecated. 
" + "Please use `airflow.providers.amazon.aws.operators.emr.EmrContainerOperator` " + "and set deferrable to True instead." ), - method_name="execute_complete", + DeprecationWarning, + stacklevel=2, ) - - # for bypassing mypy missing return error - return None # pragma: no cover - - def execute_complete(self, context: Context, event: dict[str, Any]) -> str: # type: ignore[override] - """ - Callback for when the trigger fires - returns immediately. - Relies on trigger to throw an exception, otherwise it assumes execution was - successful. - """ - if "status" in event and event["status"] == "error": - raise AirflowException(event["message"]) - self.log.info(event["message"]) - job_id: str = event["job_id"] - return job_id + return super().__init__(*args, deferrable=True, **kwargs) diff --git a/astronomer/providers/amazon/aws/triggers/emr.py b/astronomer/providers/amazon/aws/triggers/emr.py index 7ca1f9d32..9f9d06a66 100644 --- a/astronomer/providers/amazon/aws/triggers/emr.py +++ b/astronomer/providers/amazon/aws/triggers/emr.py @@ -1,10 +1,9 @@ import asyncio -from typing import Any, AsyncIterator, Dict, Iterable, List, Optional, Tuple +from typing import Any, AsyncIterator, Dict, Iterable, Optional, Tuple from airflow.triggers.base import BaseTrigger, TriggerEvent from astronomer.providers.amazon.aws.hooks.emr import ( - EmrContainerHookAsync, EmrJobFlowHookAsync, ) @@ -37,83 +36,6 @@ def __init__( super().__init__(**kwargs) -class EmrContainerOperatorTrigger(EmrContainerBaseTrigger): - """Poll for the status of EMR container until reaches terminal state""" - - INTERMEDIATE_STATES: List[str] = ["PENDING", "SUBMITTED", "RUNNING"] - FAILURE_STATES: List[str] = ["FAILED", "CANCELLED", "CANCEL_PENDING"] - SUCCESS_STATES: List[str] = ["COMPLETED"] - TERMINAL_STATES: List[str] = ["COMPLETED", "FAILED", "CANCELLED", "CANCEL_PENDING"] - - def serialize(self) -> Tuple[str, Dict[str, Any]]: - """Serializes EmrContainerOperatorTrigger arguments and classpath.""" - return ( - 
"astronomer.providers.amazon.aws.triggers.emr.EmrContainerOperatorTrigger", - { - "virtual_cluster_id": self.virtual_cluster_id, - "job_id": self.job_id, - "aws_conn_id": self.aws_conn_id, - "max_tries": self.max_tries, - "poll_interval": self.poll_interval, - }, - ) - - async def run(self) -> AsyncIterator["TriggerEvent"]: - """Run until EMR container reaches the desire state""" - hook = EmrContainerHookAsync(aws_conn_id=self.aws_conn_id, virtual_cluster_id=self.virtual_cluster_id) - try: - try_number: int = 1 - while True: - query_state = await hook.check_job_status(self.job_id) - if query_state is None: - self.log.info("Try %s: Invalid query state. Retrying again", try_number) - await asyncio.sleep(self.poll_interval) - elif query_state in self.FAILURE_STATES: - self.log.info( - "Try %s: Query execution completed. Final state is %s", try_number, query_state - ) - error_message = await hook.get_job_failure_reason(self.job_id) - message = ( - f"EMR Containers job failed. Final state is {query_state}. " - f"query_execution_id is {self.job_id}. Error: {error_message}" - ) - yield TriggerEvent( - { - "status": "error", - "message": message, - "job_id": self.job_id, - } - ) - elif query_state in self.SUCCESS_STATES: - self.log.info( - "Try %s: Query execution completed. 
Final state is %s", try_number, query_state - ) - yield TriggerEvent( - { - "status": "success", - "message": f"EMR Containers Operator success {query_state}", - "job_id": self.job_id, - } - ) - else: - self.log.info( - "Try %s: Query is still in non-terminal state - %s", try_number, query_state - ) - await asyncio.sleep(self.poll_interval) - if self.max_tries and try_number >= self.max_tries: - yield TriggerEvent( - { - "status": "error", - "message": "Timeout: Maximum retry limit exceed", - "job_id": self.job_id, - } - ) - - try_number += 1 - except Exception as e: - yield TriggerEvent({"status": "error", "message": str(e)}) - - class EmrJobFlowSensorTrigger(BaseTrigger): """ EmrJobFlowSensorTrigger is fired as deferred class with params to run the task in trigger worker, when diff --git a/tests/amazon/aws/operators/test_emr.py b/tests/amazon/aws/operators/test_emr.py index 28b38bf25..c7bc77fcc 100644 --- a/tests/amazon/aws/operators/test_emr.py +++ b/tests/amazon/aws/operators/test_emr.py @@ -1,12 +1,8 @@ import os -from unittest import mock -import pytest -from airflow.exceptions import AirflowException, TaskDeferred -from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook +from airflow.providers.amazon.aws.operators.emr import EmrContainerOperator from astronomer.providers.amazon.aws.operators.emr import EmrContainerOperatorAsync -from astronomer.providers.amazon.aws.triggers.emr import EmrContainerOperatorTrigger VIRTUAL_CLUSTER_ID = os.getenv("VIRTUAL_CLUSTER_ID", "test-cluster") JOB_ROLE_ARN = os.getenv("JOB_ROLE_ARN", "arn:aws:iam::012345678912:role/emr_eks_default_role") @@ -48,64 +44,15 @@ class TestEmrContainerOperatorAsync: - @pytest.mark.parametrize("status", EmrContainerHook.SUCCESS_STATES) - @mock.patch("astronomer.providers.amazon.aws.operators.emr.EmrContainerOperatorAsync.defer") - @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrContainerHook.check_query_status") - 
@mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrContainerHook.submit_job") - def test_emr_container_operator_async_succeeded_before_defer( - self, check_job_status, check_query_status, defer, status, context - ): - check_job_status.return_value = JOB_ID - check_query_status.return_value = status - assert EMR_OPERATOR.execute(context) == JOB_ID - - assert not defer.called - - @pytest.mark.parametrize("status", EmrContainerHook.FAILURE_STATES) - @mock.patch("astronomer.providers.amazon.aws.operators.emr.EmrContainerOperatorAsync.defer") - @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrContainerHook.get_job_failure_reason") - @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrContainerHook.check_query_status") - @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrContainerHook.submit_job") - def test_emr_container_operator_async_terminal_before_defer( - self, check_job_status, check_query_status, get_job_failure_reason, defer, status, context - ): - check_job_status.return_value = JOB_ID - check_query_status.return_value = status - - with pytest.raises(AirflowException): - EMR_OPERATOR.execute(context) - - assert not defer.called - - @pytest.mark.parametrize("status", EmrContainerHook.INTERMEDIATE_STATES) - @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrContainerHook.check_query_status") - @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrContainerHook.submit_job") - def test_emr_container_operator_async(self, check_job_status, check_query_status, status, context): - check_job_status.return_value = JOB_ID - check_query_status.return_value = status - with pytest.raises(TaskDeferred) as exc: - EMR_OPERATOR.execute(context) - - assert isinstance( - exc.value.trigger, EmrContainerOperatorTrigger - ), "Trigger is not a EmrContainerOperatorTrigger" - - @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrContainerHook.submit_job") - def test_execute_complete_success_task(self, check_job_status): - """Assert execute_complete succeed""" - 
check_job_status.return_value = JOB_ID - assert ( - EMR_OPERATOR.execute_complete( - context=None, event={"status": "success", "message": "Job completed", "job_id": JOB_ID} - ) - == JOB_ID + def test_init(self): + task = EmrContainerOperatorAsync( + task_id="start_job", + virtual_cluster_id=VIRTUAL_CLUSTER_ID, + execution_role_arn=JOB_ROLE_ARN, + release_label="emr-6.3.0-latest", + job_driver=JOB_DRIVER_ARG, + configuration_overrides=CONFIGURATION_OVERRIDES_ARG, + name="pi.py", ) - - @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrContainerHook.submit_job") - def test_execute_complete_fail_task(self, check_job_status): - """Assert execute_complete throw AirflowException""" - check_job_status.return_value = JOB_ID - with pytest.raises(AirflowException): - EMR_OPERATOR.execute_complete( - context=None, event={"status": "error", "message": "test failure message"} - ) + assert isinstance(task, EmrContainerOperator) + assert task.deferrable is True diff --git a/tests/amazon/aws/triggers/test_emr.py b/tests/amazon/aws/triggers/test_emr.py index 39a8fff02..cd976192d 100644 --- a/tests/amazon/aws/triggers/test_emr.py +++ b/tests/amazon/aws/triggers/test_emr.py @@ -4,10 +4,7 @@ import pytest from airflow.triggers.base import TriggerEvent -from astronomer.providers.amazon.aws.triggers.emr import ( - EmrContainerOperatorTrigger, - EmrJobFlowSensorTrigger, -) +from astronomer.providers.amazon.aws.triggers.emr import EmrJobFlowSensorTrigger VIRTUAL_CLUSTER_ID = "test_cluster_1" JOB_ID = "jobid-12122" @@ -152,112 +149,3 @@ async def test_emr_job_flow_sensors_trigger_exception(self, mock_cluster_detail) task = [i async for i in self.TRIGGER.run()] assert len(task) == 1 assert TriggerEvent({"status": "error", "message": "Test exception"}) in task - - -class TestEmrContainerOperatorTrigger: - TRIGGER = EmrContainerOperatorTrigger( - virtual_cluster_id=VIRTUAL_CLUSTER_ID, - name=NAME, - job_id=JOB_ID, - aws_conn_id=AWS_CONN_ID, - max_tries=MAX_RETRIES, - 
poll_interval=POLL_INTERVAL, - ) - - def test_emr_container_operator_trigger_serialization(self): - """Asserts EmrContainerOperatorTrigger correctly serializes its arguments and classpath.""" - - classpath, kwargs = self.TRIGGER.serialize() - assert classpath == "astronomer.providers.amazon.aws.triggers.emr.EmrContainerOperatorTrigger" - assert kwargs == { - "virtual_cluster_id": VIRTUAL_CLUSTER_ID, - "job_id": JOB_ID, - "aws_conn_id": AWS_CONN_ID, - "poll_interval": POLL_INTERVAL, - "max_tries": MAX_RETRIES, - } - - @pytest.mark.asyncio - @pytest.mark.parametrize( - "mock_status", - [ - "PENDING", - "SUBMITTED", - "RUNNING", - None, - ], - ) - @mock.patch("astronomer.providers.amazon.aws.hooks.emr.EmrContainerHookAsync.check_job_status") - async def test_emr_container_operator_trigger_run(self, mock_query_status, mock_status): - """Assert EmrContainerOperatorTrigger task run in trigger and sleep if state is intermediate""" - mock_query_status.return_value = mock_status - - task = asyncio.create_task(self.TRIGGER.run().__anext__()) - await asyncio.sleep(0.5) - - # TriggerEvent was not returned - assert task.done() is False - asyncio.get_event_loop().stop() - - @pytest.mark.asyncio - @mock.patch("astronomer.providers.amazon.aws.hooks.emr.EmrContainerHookAsync.check_job_status") - async def test_emr_container_operator_trigger_completed(self, mock_query_status): - """Assert EmrContainerOperatorTrigger succeed.""" - mock_query_status.return_value = "COMPLETED" - - generator = self.TRIGGER.run() - actual = await generator.asend(None) - msg = "EMR Containers Operator success COMPLETED" - assert TriggerEvent({"status": "success", "message": msg, "job_id": JOB_ID}) == actual - - @pytest.mark.asyncio - @pytest.mark.parametrize( - "mock_status", - ["FAILED", "CANCELLED", "CANCEL_PENDING"], - ) - @mock.patch("astronomer.providers.amazon.aws.hooks.emr.EmrContainerHookAsync.get_job_failure_reason") - 
@mock.patch("astronomer.providers.amazon.aws.hooks.emr.EmrContainerHookAsync.check_job_status") - async def test_emr_container_operator_trigger_failure_status( - self, mock_query_status, mock_failure_reason, mock_status - ): - """Assert EmrContainerOperatorTrigger failed.""" - mock_query_status.return_value = mock_status - mock_failure_reason.return_value = None - - generator = self.TRIGGER.run() - actual = await generator.asend(None) - message = ( - f"EMR Containers job failed. Final state is {mock_status}. " - f"query_execution_id is {JOB_ID}. Error: {None}" - ) - assert TriggerEvent({"status": "error", "message": message, "job_id": JOB_ID}) == actual - - @pytest.mark.asyncio - @mock.patch("astronomer.providers.amazon.aws.hooks.emr.EmrContainerHookAsync.check_job_status") - async def test_emr_container_operator_trigger_exception(self, mock_query_status): - """Assert EmrContainerOperatorTrigger raise exception""" - mock_query_status.side_effect = Exception("Test exception") - - task = [i async for i in self.TRIGGER.run()] - assert len(task) == 1 - assert TriggerEvent({"status": "error", "message": "Test exception"}) in task - - @pytest.mark.asyncio - @mock.patch("astronomer.providers.amazon.aws.hooks.emr.EmrContainerHookAsync.check_job_status") - async def test_emr_container_operator_trigger_timeout(self, mock_query_status): - """Assert EmrContainerOperatorTrigger max_tries exceed""" - mock_query_status.return_value = "PENDING" - trigger = EmrContainerOperatorTrigger( - name=NAME, - virtual_cluster_id=VIRTUAL_CLUSTER_ID, - job_id=JOB_ID, - aws_conn_id=AWS_CONN_ID, - poll_interval=1, - max_tries=2, - ) - generator = trigger.run() - actual = await generator.asend(None) - expected = TriggerEvent( - {"status": "error", "message": "Timeout: Maximum retry limit exceed", "job_id": JOB_ID} - ) - assert actual == expected