Skip to content

Commit

Permalink
Revert "Deprecate EmrContainerOperatorAsync (#1393)" (#1400)
Browse files Browse the repository at this point in the history
This reverts commit 60546e4.
  • Loading branch information
Lee-W authored Dec 26, 2023
1 parent daa5542 commit f1887ee
Show file tree
Hide file tree
Showing 4 changed files with 342 additions and 26 deletions.
97 changes: 85 additions & 12 deletions astronomer/providers/amazon/aws/operators/emr.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,95 @@
from __future__ import annotations

import warnings
from typing import Any

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook
from airflow.providers.amazon.aws.operators.emr import EmrContainerOperator

from astronomer.providers.amazon.aws.triggers.emr import EmrContainerOperatorTrigger
from astronomer.providers.utils.typing_compat import Context


class EmrContainerOperatorAsync(EmrContainerOperator):
    """
    An async operator that submits jobs to EMR on EKS virtual clusters.

    The job is submitted synchronously in ``execute``; if the first status check
    already shows a terminal state the operator finishes (or fails) immediately,
    otherwise it defers to ``EmrContainerOperatorTrigger`` which polls for
    completion in the triggerer process.

    :param name: The name of the job run.
    :param virtual_cluster_id: The EMR on EKS virtual cluster ID
    :param execution_role_arn: The IAM role ARN associated with the job run.
    :param release_label: The Amazon EMR release version to use for the job run.
    :param job_driver: Job configuration details, e.g. the Spark job parameters.
    :param configuration_overrides: The configuration overrides for the job run,
        specifically either application configuration or monitoring configuration.
    :param client_request_token: The client idempotency token of the job run request.
        Use this if you want to specify a unique ID to prevent two jobs from getting started.
        If no token is provided, a UUIDv4 token will be generated for you.
    :param aws_conn_id: The Airflow connection used for AWS credentials.
    :param poll_interval: Time (in seconds) to wait between two consecutive calls to check query status on EMR
    :param max_tries: Deprecated - use max_polling_attempts instead.
    :param max_polling_attempts: Maximum number of times to wait for the job run to finish.
        Defaults to None, which will poll until the job is *not* in a pending, submitted, or running state.
    :param tags: The tags assigned to job runs. Defaults to None
    """

    def execute(self, context: Context) -> str | None:
        """Submit the job; return early on a terminal state, otherwise defer to the trigger."""
        hook = EmrContainerHook(aws_conn_id=self.aws_conn_id, virtual_cluster_id=self.virtual_cluster_id)
        job_id = hook.submit_job(
            name=self.name,
            execution_role_arn=self.execution_role_arn,
            release_label=self.release_label,
            job_driver=self.job_driver,
            configuration_overrides=self.configuration_overrides,
            client_request_token=self.client_request_token,
            tags=self.tags,
        )
        try:
            # for apache-airflow-providers-amazon<6.0.0
            polling_attempts = self.max_tries  # type: ignore[attr-defined]
        except AttributeError:  # pragma: no cover
            # for apache-airflow-providers-amazon>=6.0.0
            # max_tries is deprecated so instead of max_tries using self.max_polling_attempts
            polling_attempts = self.max_polling_attempts

        # One synchronous status check: if the job already finished we avoid
        # the round-trip through the triggerer entirely.
        query_state = hook.check_query_status(job_id)
        if query_state in hook.SUCCESS_STATES:
            self.log.info(
                f"Try : Query execution completed. Final state is {query_state}"
                f"EMR Containers Operator success {query_state}"
            )
            return job_id

        if query_state in hook.FAILURE_STATES:
            error_message = self.hook.get_job_failure_reason(job_id)
            raise AirflowException(
                f"EMR Containers job failed. Final state is {query_state}. "
                f"query_execution_id is {job_id}. Error: {error_message}"
            )

        # Non-terminal state: hand off polling to the trigger; control returns
        # via execute_complete when the trigger fires.
        self.defer(
            timeout=self.execution_timeout,
            trigger=EmrContainerOperatorTrigger(
                virtual_cluster_id=self.virtual_cluster_id,
                job_id=job_id,
                aws_conn_id=self.aws_conn_id,
                poll_interval=self.poll_interval,
                max_tries=polling_attempts,
            ),
            method_name="execute_complete",
        )

        # for bypassing mypy missing return error
        return None  # pragma: no cover

    def execute_complete(self, context: Context, event: dict[str, Any]) -> str:  # type: ignore[override]
        """
        Callback for when the trigger fires - returns immediately.
        Relies on trigger to throw an exception, otherwise it assumes execution was
        successful.
        """
        if "status" in event and event["status"] == "error":
            raise AirflowException(event["message"])
        self.log.info(event["message"])
        job_id: str = event["job_id"]
        return job_id
80 changes: 79 additions & 1 deletion astronomer/providers/amazon/aws/triggers/emr.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import asyncio
from typing import Any, AsyncIterator, Dict, Iterable, Optional, Tuple
from typing import Any, AsyncIterator, Dict, Iterable, List, Optional, Tuple

from airflow.triggers.base import BaseTrigger, TriggerEvent

from astronomer.providers.amazon.aws.hooks.emr import (
EmrContainerHookAsync,
EmrJobFlowHookAsync,
)

Expand Down Expand Up @@ -36,6 +37,83 @@ def __init__(
super().__init__(**kwargs)


class EmrContainerOperatorTrigger(EmrContainerBaseTrigger):
    """Poll for the status of EMR container until reaches terminal state"""

    # State buckets used to decide whether to keep polling or stop.
    INTERMEDIATE_STATES: List[str] = ["PENDING", "SUBMITTED", "RUNNING"]
    FAILURE_STATES: List[str] = ["FAILED", "CANCELLED", "CANCEL_PENDING"]
    SUCCESS_STATES: List[str] = ["COMPLETED"]
    TERMINAL_STATES: List[str] = ["COMPLETED", "FAILED", "CANCELLED", "CANCEL_PENDING"]

    def serialize(self) -> Tuple[str, Dict[str, Any]]:
        """Serializes EmrContainerOperatorTrigger arguments and classpath."""
        return (
            "astronomer.providers.amazon.aws.triggers.emr.EmrContainerOperatorTrigger",
            {
                "virtual_cluster_id": self.virtual_cluster_id,
                "job_id": self.job_id,
                "aws_conn_id": self.aws_conn_id,
                "max_tries": self.max_tries,
                "poll_interval": self.poll_interval,
            },
        )

    async def run(self) -> AsyncIterator["TriggerEvent"]:
        """Run until EMR container reaches the desire state"""
        hook = EmrContainerHookAsync(aws_conn_id=self.aws_conn_id, virtual_cluster_id=self.virtual_cluster_id)
        try:
            try_number: int = 1
            while True:
                query_state = await hook.check_job_status(self.job_id)
                if query_state is None:
                    self.log.info("Try %s: Invalid query state. Retrying again", try_number)
                    await asyncio.sleep(self.poll_interval)
                elif query_state in self.FAILURE_STATES:
                    self.log.info(
                        "Try %s: Query execution completed. Final state is %s", try_number, query_state
                    )
                    error_message = await hook.get_job_failure_reason(self.job_id)
                    message = (
                        f"EMR Containers job failed. Final state is {query_state}. "
                        f"query_execution_id is {self.job_id}. Error: {error_message}"
                    )
                    yield TriggerEvent(
                        {
                            "status": "error",
                            "message": message,
                            "job_id": self.job_id,
                        }
                    )
                    # Terminal event emitted: stop polling so no duplicate events fire.
                    return
                elif query_state in self.SUCCESS_STATES:
                    self.log.info(
                        "Try %s: Query execution completed. Final state is %s", try_number, query_state
                    )
                    yield TriggerEvent(
                        {
                            "status": "success",
                            "message": f"EMR Containers Operator success {query_state}",
                            "job_id": self.job_id,
                        }
                    )
                    # Terminal event emitted: stop polling so no duplicate events fire.
                    return
                else:
                    self.log.info(
                        "Try %s: Query is still in non-terminal state - %s", try_number, query_state
                    )
                    await asyncio.sleep(self.poll_interval)
                    if self.max_tries and try_number >= self.max_tries:
                        yield TriggerEvent(
                            {
                                "status": "error",
                                "message": "Timeout: Maximum retry limit exceed",
                                "job_id": self.job_id,
                            }
                        )
                        # Without this return the loop would keep polling forever
                        # after the timeout event has been emitted.
                        return

                try_number += 1
        except Exception as e:
            yield TriggerEvent({"status": "error", "message": str(e)})


class EmrJobFlowSensorTrigger(BaseTrigger):
"""
EmrJobFlowSensorTrigger is fired as deferred class with params to run the task in trigger worker, when
Expand Down
77 changes: 65 additions & 12 deletions tests/amazon/aws/operators/test_emr.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import os
from unittest import mock

from airflow.providers.amazon.aws.operators.emr import EmrContainerOperator
import pytest
from airflow.exceptions import AirflowException, TaskDeferred
from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook

from astronomer.providers.amazon.aws.operators.emr import EmrContainerOperatorAsync
from astronomer.providers.amazon.aws.triggers.emr import EmrContainerOperatorTrigger

VIRTUAL_CLUSTER_ID = os.getenv("VIRTUAL_CLUSTER_ID", "test-cluster")
JOB_ROLE_ARN = os.getenv("JOB_ROLE_ARN", "arn:aws:iam::012345678912:role/emr_eks_default_role")
Expand Down Expand Up @@ -44,15 +48,64 @@


class TestEmrContainerOperatorAsync:
    # Job already succeeded on the first synchronous status check: the operator
    # must return the job id without deferring.
    @pytest.mark.parametrize("status", EmrContainerHook.SUCCESS_STATES)
    @mock.patch("astronomer.providers.amazon.aws.operators.emr.EmrContainerOperatorAsync.defer")
    @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrContainerHook.check_query_status")
    @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrContainerHook.submit_job")
    def test_emr_container_operator_async_succeeded_before_defer(
        self, check_job_status, check_query_status, defer, status, context
    ):
        check_job_status.return_value = JOB_ID
        check_query_status.return_value = status
        assert EMR_OPERATOR.execute(context) == JOB_ID

        assert not defer.called

    # Job already failed on the first synchronous status check: the operator
    # must raise without deferring.
    @pytest.mark.parametrize("status", EmrContainerHook.FAILURE_STATES)
    @mock.patch("astronomer.providers.amazon.aws.operators.emr.EmrContainerOperatorAsync.defer")
    @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrContainerHook.get_job_failure_reason")
    @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrContainerHook.check_query_status")
    @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrContainerHook.submit_job")
    def test_emr_container_operator_async_terminal_before_defer(
        self, check_job_status, check_query_status, get_job_failure_reason, defer, status, context
    ):
        check_job_status.return_value = JOB_ID
        check_query_status.return_value = status

        with pytest.raises(AirflowException):
            EMR_OPERATOR.execute(context)

        assert not defer.called

    # Job still in a non-terminal state: the operator must defer with the
    # async trigger.
    @pytest.mark.parametrize("status", EmrContainerHook.INTERMEDIATE_STATES)
    @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrContainerHook.check_query_status")
    @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrContainerHook.submit_job")
    def test_emr_container_operator_async(self, check_job_status, check_query_status, status, context):
        check_job_status.return_value = JOB_ID
        check_query_status.return_value = status
        with pytest.raises(TaskDeferred) as exc:
            EMR_OPERATOR.execute(context)

        assert isinstance(
            exc.value.trigger, EmrContainerOperatorTrigger
        ), "Trigger is not a EmrContainerOperatorTrigger"

    @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrContainerHook.submit_job")
    def test_execute_complete_success_task(self, check_job_status):
        """Assert execute_complete succeed"""
        check_job_status.return_value = JOB_ID
        assert (
            EMR_OPERATOR.execute_complete(
                context=None, event={"status": "success", "message": "Job completed", "job_id": JOB_ID}
            )
            == JOB_ID
        )

    @mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrContainerHook.submit_job")
    def test_execute_complete_fail_task(self, check_job_status):
        """Assert execute_complete throw AirflowException"""
        check_job_status.return_value = JOB_ID
        with pytest.raises(AirflowException):
            EMR_OPERATOR.execute_complete(
                context=None, event={"status": "error", "message": "test failure message"}
            )
Loading

0 comments on commit f1887ee

Please sign in to comment.