diff --git a/airflow/providers/amazon/aws/sensors/emr.py b/airflow/providers/amazon/aws/sensors/emr.py index c5c22f97bc3e2..30a4c943c69ba 100644 --- a/airflow/providers/amazon/aws/sensors/emr.py +++ b/airflow/providers/amazon/aws/sensors/emr.py @@ -24,7 +24,11 @@ from deprecated import deprecated from airflow.configuration import conf -from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException +from airflow.exceptions import ( + AirflowException, + AirflowProviderDeprecationWarning, + AirflowSkipException, +) from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, EmrServerlessHook from airflow.providers.amazon.aws.links.emr import EmrClusterLink, EmrLogsLink, get_log_uri from airflow.providers.amazon.aws.triggers.emr import ( @@ -231,7 +235,9 @@ def poke(self, context: Context) -> bool: if state in EmrServerlessHook.APPLICATION_FAILURE_STATES: # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1 - failure_message = f"EMR Serverless job failed: {self.failure_message_from_response(response)}" + failure_message = ( + f"EMR Serverless application failed: {self.failure_message_from_response(response)}" + ) if self.soft_fail: raise AirflowSkipException(failure_message) raise AirflowException(failure_message) diff --git a/tests/providers/amazon/aws/sensors/test_emr_serverless_application.py b/tests/providers/amazon/aws/sensors/test_emr_serverless_application.py new file mode 100644 index 0000000000000..c35d84e7fa5af --- /dev/null +++ b/tests/providers/amazon/aws/sensors/test_emr_serverless_application.py @@ -0,0 +1,90 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest + +from airflow.exceptions import AirflowException, AirflowSkipException +from airflow.providers.amazon.aws.sensors.emr import EmrServerlessApplicationSensor + + +class TestEmrServerlessApplicationSensor: + def setup_method(self): + self.app_id = "vzwemreks" + self.job_run_id = "job1234" + self.sensor = EmrServerlessApplicationSensor( + task_id="test_emrcontainer_sensor", + application_id=self.app_id, + aws_conn_id="aws_default", + ) + + def set_get_application_return_value(self, return_value: dict[str, str]): + self.mock_hook = MagicMock() + self.mock_hook.conn.get_application.return_value = return_value + self.sensor.hook = self.mock_hook + + def assert_get_application_was_called_once_with_app_id(self): + self.mock_hook.conn.get_application.assert_called_once_with(applicationId=self.app_id) + + +class TestPokeReturnValue(TestEmrServerlessApplicationSensor): + @pytest.mark.parametrize( + "state, expected_result", + [ + ("CREATING", False), + ("STARTING", False), + ("STOPPING", False), + ("CREATED", True), + ("STARTED", True), + ], + ) + def test_poke_returns_expected_result_for_states(self, state, expected_result): + get_application_return_value = {"application": {"state": state}} + self.set_get_application_return_value(get_application_return_value) + assert self.sensor.poke(None) == expected_result + self.assert_get_application_was_called_once_with_app_id() + + +class TestPokeRaisesAirflowException(TestEmrServerlessApplicationSensor): + @pytest.mark.parametrize("state", ["STOPPED", "TERMINATED"]) + def test_poke_raises_airflow_exception_with_failure_states(self, state): + state_details = f"mock {state}" + exception_msg = f"EMR Serverless application failed: {state_details}" + get_job_run_return_value = {"application": {"state": state, "stateDetails": state_details}} + self.set_get_application_return_value(get_job_run_return_value) + + with pytest.raises(AirflowException) as ctx: + self.sensor.poke(None) + + assert exception_msg == str(ctx.value) + self.assert_get_application_was_called_once_with_app_id() + + +class TestPokeRaisesAirflowSkipException(TestEmrServerlessApplicationSensor): + def test_when_state_is_failed_and_soft_fail_is_true_poke_should_raise_skip_exception(self): + self.sensor.soft_fail = True + self.set_get_application_return_value( + {"application": {"state": "STOPPED", "stateDetails": "mock stopped"}} + ) + with pytest.raises(AirflowSkipException) as ctx: + self.sensor.poke(None) + assert "EMR Serverless application failed: mock stopped" == str(ctx.value) + self.assert_get_application_was_called_once_with_app_id() + self.sensor.soft_fail = False diff --git a/tests/providers/amazon/aws/sensors/test_emr_serverless_job.py b/tests/providers/amazon/aws/sensors/test_emr_serverless_job.py new file mode 100644 index 0000000000000..299efe3fd277e --- /dev/null +++ b/tests/providers/amazon/aws/sensors/test_emr_serverless_job.py @@ -0,0 +1,91 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest + +from airflow.exceptions import AirflowException, AirflowSkipException +from airflow.providers.amazon.aws.sensors.emr import EmrServerlessJobSensor + + +class TestEmrServerlessJobSensor: + def setup_method(self): + self.app_id = "vzwemreks" + self.job_run_id = "job1234" + self.sensor = EmrServerlessJobSensor( + task_id="test_emrcontainer_sensor", + application_id=self.app_id, + job_run_id=self.job_run_id, + aws_conn_id="aws_default", + ) + + def set_get_job_run_return_value(self, return_value: dict[str, str]): + self.mock_hook = MagicMock() + self.mock_hook.conn.get_job_run.return_value = return_value + self.sensor.hook = self.mock_hook + + def assert_get_job_run_was_called_once_with_app_and_run_id(self): + self.mock_hook.conn.get_job_run.assert_called_once_with( + applicationId=self.app_id, jobRunId=self.job_run_id + ) + + +class TestPokeReturnValue(TestEmrServerlessJobSensor): + @pytest.mark.parametrize( + "state, expected_result", + [ + ("PENDING", False), + ("RUNNING", False), + ("SCHEDULED", False), + ("SUBMITTED", False), + ("SUCCESS", True), + ], + ) + def test_poke_returns_expected_result_for_states(self, state, expected_result): + get_job_run_return_value = {"jobRun": {"state": state}} + self.set_get_job_run_return_value(get_job_run_return_value) + assert self.sensor.poke(None) == expected_result + self.assert_get_job_run_was_called_once_with_app_and_run_id() + + +class TestPokeRaisesAirflowException(TestEmrServerlessJobSensor): + @pytest.mark.parametrize("state", ["FAILED", "CANCELLING", "CANCELLED"]) + def test_poke_raises_airflow_exception_with_specified_states(self, state): + state_details = f"mock {state}" + exception_msg = f"EMR Serverless job failed: {state_details}" + get_job_run_return_value = {"jobRun": {"state": state, "stateDetails": state_details}} + self.set_get_job_run_return_value(get_job_run_return_value) + + with pytest.raises(AirflowException) as ctx: + self.sensor.poke(None) + + assert exception_msg == str(ctx.value) + self.assert_get_job_run_was_called_once_with_app_and_run_id() + + +class TestPokeRaisesAirflowSkipException(TestEmrServerlessJobSensor): + def test_when_state_is_failed_and_soft_fail_is_true_poke_should_raise_skip_exception(self): + self.sensor.soft_fail = True + self.set_get_job_run_return_value({"jobRun": {"state": "FAILED", "stateDetails": "mock failed"}}) + with pytest.raises(AirflowSkipException) as ctx: + self.sensor.poke(None) + assert "EMR Serverless job failed: mock failed" == str(ctx.value) + self.assert_get_job_run_was_called_once_with_app_and_run_id() + self.sensor.soft_fail = False