-
Notifications
You must be signed in to change notification settings - Fork 26
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Browse files
Browse the repository at this point in the history
This reverts commit 60546e4.
- Loading branch information
Showing
4 changed files
with
342 additions
and
26 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,22 +1,95 @@ | ||
import warnings | ||
from __future__ import annotations | ||
|
||
from typing import Any | ||
|
||
from airflow.exceptions import AirflowException | ||
from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook | ||
from airflow.providers.amazon.aws.operators.emr import EmrContainerOperator | ||
|
||
from astronomer.providers.amazon.aws.triggers.emr import EmrContainerOperatorTrigger | ||
from astronomer.providers.utils.typing_compat import Context | ||
|
||
|
||
class EmrContainerOperatorAsync(EmrContainerOperator): | ||
""" | ||
This class is deprecated. | ||
Please use :class: `~airflow.providers.amazon.aws.operators.emr.EmrContainerOperator`. | ||
An async operator that submits jobs to EMR on EKS virtual clusters. | ||
:param name: The name of the job run. | ||
:param virtual_cluster_id: The EMR on EKS virtual cluster ID | ||
:param execution_role_arn: The IAM role ARN associated with the job run. | ||
:param release_label: The Amazon EMR release version to use for the job run. | ||
:param job_driver: Job configuration details, e.g. the Spark job parameters. | ||
:param configuration_overrides: The configuration overrides for the job run, | ||
specifically either application configuration or monitoring configuration. | ||
:param client_request_token: The client idempotency token of the job run request. | ||
Use this if you want to specify a unique ID to prevent two jobs from getting started. | ||
If no token is provided, a UUIDv4 token will be generated for you. | ||
:param aws_conn_id: The Airflow connection used for AWS credentials. | ||
:param poll_interval: Time (in seconds) to wait between two consecutive calls to check query status on EMR | ||
:param max_tries: Deprecated - use max_polling_attempts instead. | ||
:param max_polling_attempts: Maximum number of times to wait for the job run to finish. | ||
Defaults to None, which will poll until the job is *not* in a pending, submitted, or running state. | ||
:param tags: The tags assigned to job runs. Defaults to None | ||
""" | ||
|
||
def __init__(self, *args, **kwargs): # type: ignore[no-untyped-def] | ||
warnings.warn( | ||
( | ||
"This module is deprecated. " | ||
"Please use `airflow.providers.amazon.aws.operators.emr.EmrContainerOperator` " | ||
"and set deferrable to True instead." | ||
def execute(self, context: Context) -> str | None: | ||
"""Deferred and give control to trigger""" | ||
hook = EmrContainerHook(aws_conn_id=self.aws_conn_id, virtual_cluster_id=self.virtual_cluster_id) | ||
job_id = hook.submit_job( | ||
name=self.name, | ||
execution_role_arn=self.execution_role_arn, | ||
release_label=self.release_label, | ||
job_driver=self.job_driver, | ||
configuration_overrides=self.configuration_overrides, | ||
client_request_token=self.client_request_token, | ||
tags=self.tags, | ||
) | ||
try: | ||
# for apache-airflow-providers-amazon<6.0.0 | ||
polling_attempts = self.max_tries # type: ignore[attr-defined] | ||
except AttributeError: # pragma: no cover | ||
# for apache-airflow-providers-amazon>=6.0.0 | ||
# max_tries is deprecated so instead of max_tries using self.max_polling_attempts | ||
polling_attempts = self.max_polling_attempts | ||
|
||
query_state = hook.check_query_status(job_id) | ||
if query_state in hook.SUCCESS_STATES: | ||
self.log.info( | ||
f"Try : Query execution completed. Final state is {query_state}" | ||
f"EMR Containers Operator success {query_state}" | ||
) | ||
return job_id | ||
|
||
if query_state in hook.FAILURE_STATES: | ||
error_message = self.hook.get_job_failure_reason(job_id) | ||
raise AirflowException( | ||
f"EMR Containers job failed. Final state is {query_state}. " | ||
f"query_execution_id is {job_id}. Error: {error_message}" | ||
) | ||
|
||
self.defer( | ||
timeout=self.execution_timeout, | ||
trigger=EmrContainerOperatorTrigger( | ||
virtual_cluster_id=self.virtual_cluster_id, | ||
job_id=job_id, | ||
aws_conn_id=self.aws_conn_id, | ||
poll_interval=self.poll_interval, | ||
max_tries=polling_attempts, | ||
), | ||
DeprecationWarning, | ||
stacklevel=2, | ||
method_name="execute_complete", | ||
) | ||
return super().__init__(*args, deferrable=True, **kwargs) | ||
|
||
# for bypassing mypy missing return error | ||
return None # pragma: no cover | ||
|
||
def execute_complete(self, context: Context, event: dict[str, Any]) -> str: # type: ignore[override] | ||
""" | ||
Callback for when the trigger fires - returns immediately. | ||
Relies on trigger to throw an exception, otherwise it assumes execution was | ||
successful. | ||
""" | ||
if "status" in event and event["status"] == "error": | ||
raise AirflowException(event["message"]) | ||
self.log.info(event["message"]) | ||
job_id: str = event["job_id"] | ||
return job_id |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.