From a4ecdc910f8d20580f204f8491cc7a0534de0fae Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Tue, 26 Sep 2023 03:38:34 +0800 Subject: [PATCH] fix(providers/amazon): respect soft_fail argument when exception is raised (#34134) --- .../amazon/aws/sensors/cloud_formation.py | 15 ++++++- airflow/providers/amazon/aws/sensors/dms.py | 16 +++++--- airflow/providers/amazon/aws/sensors/ec2.py | 8 +++- airflow/providers/amazon/aws/sensors/eks.py | 13 +++--- airflow/providers/amazon/aws/sensors/emr.py | 40 +++++++++++++++---- .../providers/amazon/aws/sensors/test_eks.py | 16 ++++---- 6 files changed, 78 insertions(+), 30 deletions(-) diff --git a/airflow/providers/amazon/aws/sensors/cloud_formation.py b/airflow/providers/amazon/aws/sensors/cloud_formation.py index 942735c4bcd31..5c6b1f2246938 100644 --- a/airflow/providers/amazon/aws/sensors/cloud_formation.py +++ b/airflow/providers/amazon/aws/sensors/cloud_formation.py @@ -24,6 +24,7 @@ if TYPE_CHECKING: from airflow.utils.context import Context +from airflow.exceptions import AirflowSkipException from airflow.providers.amazon.aws.hooks.cloud_formation import CloudFormationHook from airflow.sensors.base import BaseSensorOperator @@ -57,7 +58,12 @@ def poke(self, context: Context): return True if stack_status in ("CREATE_IN_PROGRESS", None): return False - raise ValueError(f"Stack {self.stack_name} in bad state: {stack_status}") + + # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 + message = f"Stack {self.stack_name} in bad state: {stack_status}" + if self.soft_fail: + raise AirflowSkipException(message) + raise ValueError(message) @cached_property def hook(self) -> CloudFormationHook: @@ -101,7 +107,12 @@ def poke(self, context: Context): return True if stack_status == "DELETE_IN_PROGRESS": return False - raise ValueError(f"Stack {self.stack_name} in bad state: {stack_status}") + + # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 + message = f"Stack {self.stack_name} in bad state: {stack_status}" + if self.soft_fail: + raise AirflowSkipException(message) + raise ValueError(message) @cached_property def hook(self) -> CloudFormationHook: diff --git a/airflow/providers/amazon/aws/sensors/dms.py b/airflow/providers/amazon/aws/sensors/dms.py index 993459f5f21fc..d6ce3b3b1b94a 100644 --- a/airflow/providers/amazon/aws/sensors/dms.py +++ b/airflow/providers/amazon/aws/sensors/dms.py @@ -22,7 +22,7 @@ from deprecated import deprecated -from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning +from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException from airflow.providers.amazon.aws.hooks.dms import DmsHook from airflow.sensors.base import BaseSensorOperator @@ -75,9 +75,11 @@ def poke(self, context: Context): status: str | None = self.hook.get_task_status(self.replication_task_arn) if not status: - raise AirflowException( - f"Failed to read task status, task with ARN {self.replication_task_arn} not found" - ) + # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 + message = f"Failed to read task status, task with ARN {self.replication_task_arn} not found" + if self.soft_fail: + raise AirflowSkipException(message) + raise AirflowException(message) self.log.info("DMS Replication task (%s) has status: %s", self.replication_task_arn, status) @@ -85,7 +87,11 @@ def poke(self, context: Context): return True if status in self.termination_statuses: - raise AirflowException(f"Unexpected status: {status}") + # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 + message = f"Unexpected status: {status}" + if self.soft_fail: + raise AirflowSkipException(message) + raise AirflowException(message) return False diff --git a/airflow/providers/amazon/aws/sensors/ec2.py b/airflow/providers/amazon/aws/sensors/ec2.py index 2b7b63f7e6c7a..cdebd1b44a078 100644 --- a/airflow/providers/amazon/aws/sensors/ec2.py +++ b/airflow/providers/amazon/aws/sensors/ec2.py @@ -21,7 +21,7 @@ from typing import TYPE_CHECKING, Any, Sequence from airflow.configuration import conf -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, AirflowSkipException from airflow.providers.amazon.aws.hooks.ec2 import EC2Hook from airflow.providers.amazon.aws.triggers.ec2 import EC2StateSensorTrigger from airflow.sensors.base import BaseSensorOperator @@ -94,5 +94,9 @@ def poke(self, context: Context): def execute_complete(self, context, event=None): if event["status"] != "success": - raise AirflowException(f"Error: {event}") + # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 + message = f"Error: {event}" + if self.soft_fail: + raise AirflowSkipException(message) + raise AirflowException(message) return diff --git a/airflow/providers/amazon/aws/sensors/eks.py b/airflow/providers/amazon/aws/sensors/eks.py index 4acf41fa03780..28f4312928460 100644 --- a/airflow/providers/amazon/aws/sensors/eks.py +++ b/airflow/providers/amazon/aws/sensors/eks.py @@ -21,7 +21,7 @@ from functools import cached_property from typing import TYPE_CHECKING, Sequence -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, AirflowSkipException from airflow.providers.amazon.aws.hooks.eks import ( ClusterStates, EksHook, @@ -53,9 +53,6 @@ NodegroupStates.NONEXISTENT, } ) -UNEXPECTED_TERMINAL_STATE_MSG = ( - "Terminal state reached. Current state: {current_state}, Expected state: {target_state}" -) class EksBaseSensor(BaseSensorOperator): @@ -109,9 +106,11 @@ def poke(self, context: Context) -> bool: self.log.info("Current state: %s", state) if state in (self.get_terminal_states() - {self.target_state}): # If we reach a terminal state which is not the target state: - raise AirflowException( - UNEXPECTED_TERMINAL_STATE_MSG.format(current_state=state, target_state=self.target_state) - ) + # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 + message = f"Terminal state reached. Current state: {state}, Expected state: {self.target_state}" + if self.soft_fail: + raise AirflowSkipException(message) + raise AirflowException(message) return state == self.target_state @abstractmethod diff --git a/airflow/providers/amazon/aws/sensors/emr.py b/airflow/providers/amazon/aws/sensors/emr.py index 2fea7509624a0..c184c28f1aa44 100644 --- a/airflow/providers/amazon/aws/sensors/emr.py +++ b/airflow/providers/amazon/aws/sensors/emr.py @@ -24,7 +24,7 @@ from deprecated import deprecated from airflow.configuration import conf -from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning +from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, EmrServerlessHook from airflow.providers.amazon.aws.links.emr import EmrClusterLink, EmrLogsLink, get_log_uri from airflow.providers.amazon.aws.triggers.emr import ( @@ -82,7 +82,11 @@ def poke(self, context: Context): return True if state in self.failed_states: - raise AirflowException(f"EMR job failed: {self.failure_message_from_response(response)}") + # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 + message = f"EMR job failed: {self.failure_message_from_response(response)}" + if self.soft_fail: + raise AirflowSkipException(message) + raise AirflowException(message) return False @@ -156,6 +160,9 @@ def poke(self, context: Context) -> bool: if state in EmrServerlessHook.JOB_FAILURE_STATES: failure_message = f"EMR Serverless job failed: {self.failure_message_from_response(response)}" + # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 + if self.soft_fail: + raise AirflowSkipException(failure_message) raise AirflowException(failure_message) return state in self.target_states @@ -210,7 +217,10 @@ def poke(self, context: Context) -> bool: state = response["application"]["state"] if state in EmrServerlessHook.APPLICATION_FAILURE_STATES: + # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 failure_message = f"EMR Serverless job failed: {self.failure_message_from_response(response)}" + if self.soft_fail: + raise AirflowSkipException(failure_message) raise AirflowException(failure_message) return state in self.target_states @@ -295,7 +305,11 @@ def poke(self, context: Context) -> bool: ) if state in self.FAILURE_STATES: - raise AirflowException("EMR Containers sensor failed") + # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 + message = "EMR Containers sensor failed" + if self.soft_fail: + raise AirflowSkipException(message) + raise AirflowException(message) if state in self.INTERMEDIATE_STATES: return False @@ -323,7 +337,11 @@ def execute(self, context: Context): def execute_complete(self, context, event=None): if event["status"] != "success": - raise AirflowException(f"Error while running job: {event}") + # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 + message = f"Error while running job: {event}" + if self.soft_fail: + raise AirflowSkipException(message) + raise AirflowException(message) else: self.log.info(event["message"]) @@ -508,9 +526,13 @@ def execute(self, context: Context) -> None: method_name="execute_complete", ) - def execute_complete(self, context, event=None): + def execute_complete(self, context: Context, event=None) -> None: if event["status"] != "success": - raise AirflowException(f"Error while running job: {event}") + # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 + message = f"Error while running job: {event}" + if self.soft_fail: + raise AirflowSkipException(message) + raise AirflowException(message) self.log.info("Job completed.") @@ -637,6 +659,10 @@ def execute(self, context: Context) -> None: def execute_complete(self, context, event=None): if event["status"] != "success": - raise AirflowException(f"Error while running job: {event}") + # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 + message = f"Error while running job: {event}" + if self.soft_fail: + raise AirflowSkipException(message) + raise AirflowException(message) self.log.info("Job completed.") diff --git a/tests/providers/amazon/aws/sensors/test_eks.py b/tests/providers/amazon/aws/sensors/test_eks.py index 0bb625532d9ff..9749e582f6a10 100644 --- a/tests/providers/amazon/aws/sensors/test_eks.py +++ b/tests/providers/amazon/aws/sensors/test_eks.py @@ -31,7 +31,6 @@ CLUSTER_TERMINAL_STATES, FARGATE_TERMINAL_STATES, NODEGROUP_TERMINAL_STATES, - UNEXPECTED_TERMINAL_STATE_MSG, EksClusterStateSensor, EksFargateProfileStateSensor, EksNodegroupStateSensor, @@ -75,8 +74,9 @@ def test_poke_reached_pending_state(self, mock_get_cluster_state, setUp, pending def test_poke_reached_unexpected_terminal_state( self, mock_get_cluster_state, setUp, unexpected_terminal_state ): - expected_message = UNEXPECTED_TERMINAL_STATE_MSG.format( - current_state=unexpected_terminal_state, target_state=self.target_state + expected_message = ( + f"Terminal state reached. Current state: {unexpected_terminal_state}, " + f"Expected state: {self.target_state}" ) mock_get_cluster_state.return_value = unexpected_terminal_state @@ -122,8 +122,9 @@ def test_poke_reached_pending_state(self, mock_get_fargate_profile_state, setUp, def test_poke_reached_unexpected_terminal_state( self, mock_get_fargate_profile_state, setUp, unexpected_terminal_state ): - expected_message = UNEXPECTED_TERMINAL_STATE_MSG.format( - current_state=unexpected_terminal_state, target_state=self.target_state + expected_message = ( + f"Terminal state reached. Current state: {unexpected_terminal_state}, " + f"Expected state: {self.target_state}" ) mock_get_fargate_profile_state.return_value = unexpected_terminal_state @@ -171,8 +172,9 @@ def test_poke_reached_pending_state(self, mock_get_nodegroup_state, setUp, pendi def test_poke_reached_unexpected_terminal_state( self, mock_get_nodegroup_state, setUp, unexpected_terminal_state ): - expected_message = UNEXPECTED_TERMINAL_STATE_MSG.format( - current_state=unexpected_terminal_state, target_state=self.target_state + expected_message = ( + f"Terminal state reached. Current state: {unexpected_terminal_state}, " + f"Expected state: {self.target_state}" ) mock_get_nodegroup_state.return_value = unexpected_terminal_state