From 03a0b7267215ea2ac1bce6c60eca1a41f747e84b Mon Sep 17 00:00:00 2001 From: Sangseung Lee Date: Fri, 17 Nov 2023 18:38:51 +0900 Subject: [PATCH] Added name field to template_fields in EmrServerlessStartJobOperator (#35648) * Added name field to template_fields in EmrServerlessStartJobOperator * Moved default name setting operation from constructor to execute method * Update EmrServerlessStartJobOperator test class --- airflow/providers/amazon/aws/operators/emr.py | 4 +++- .../aws/operators/test_emr_serverless.py | 21 ++++++++++--------- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/airflow/providers/amazon/aws/operators/emr.py b/airflow/providers/amazon/aws/operators/emr.py index ea1ec0496a002..1067464474d0c 100644 --- a/airflow/providers/amazon/aws/operators/emr.py +++ b/airflow/providers/amazon/aws/operators/emr.py @@ -1173,6 +1173,7 @@ class EmrServerlessStartJobOperator(BaseOperator): "execution_role_arn", "job_driver", "configuration_overrides", + "name", ) template_fields_renderers = { @@ -1226,7 +1227,7 @@ def __init__( self.configuration_overrides = configuration_overrides self.wait_for_completion = wait_for_completion self.config = config or {} - self.name = name or self.config.pop("name", f"emr_serverless_job_airflow_{uuid4()}") + self.name = name self.waiter_max_attempts = int(waiter_max_attempts) # type: ignore[arg-type] self.waiter_delay = int(waiter_delay) # type: ignore[arg-type] self.job_id: str | None = None @@ -1268,6 +1269,7 @@ def execute(self, context: Context, event: dict[str, Any] | None = None) -> str status_args=["application.state", "application.stateDetails"], ) self.log.info("Starting job on Application: %s", self.application_id) + self.name = self.name or self.config.pop("name", f"emr_serverless_job_airflow_{uuid4()}") response = self.hook.conn.start_job_run( clientToken=self.client_request_token, applicationId=self.application_id, diff --git a/tests/providers/amazon/aws/operators/test_emr_serverless.py b/tests/providers/amazon/aws/operators/test_emr_serverless.py index f72cebbfec760..edb2ddc0f922c 100644 --- a/tests/providers/amazon/aws/operators/test_emr_serverless.py +++ b/tests/providers/amazon/aws/operators/test_emr_serverless.py @@ -375,8 +375,8 @@ def test_job_run_app_started(self, mock_conn, mock_get_waiter): job_driver=job_driver, configuration_overrides=configuration_overrides, ) - default_name = operator.name id = operator.execute(None) + default_name = operator.name assert operator.wait_for_completion is True mock_conn.get_application.assert_called_once_with(applicationId=application_id) @@ -413,11 +413,12 @@ def test_job_run_job_failed(self, mock_conn, mock_get_waiter): job_driver=job_driver, configuration_overrides=configuration_overrides, ) - default_name = operator.name with pytest.raises(AirflowException) as ex_message: id = operator.execute(None) assert id == job_run_id assert "Serverless Job failed:" in str(ex_message.value) + default_name = operator.name + mock_conn.get_application.assert_called_once_with(applicationId=application_id) mock_conn.start_job_run.assert_called_once_with( clientToken=client_request_token, @@ -446,9 +447,8 @@ def test_job_run_app_not_started(self, mock_conn, mock_get_waiter): job_driver=job_driver, configuration_overrides=configuration_overrides, ) - default_name = operator.name - id = operator.execute(None) + default_name = operator.name assert operator.wait_for_completion is True mock_conn.get_application.assert_called_once_with(applicationId=application_id) @@ -516,8 +516,8 @@ def test_job_run_app_not_started_no_wait_for_completion(self, mock_conn, mock_ge configuration_overrides=configuration_overrides, wait_for_completion=False, ) - default_name = operator.name id = operator.execute(None) + default_name = operator.name mock_conn.get_application.assert_called_once_with(applicationId=application_id) mock_get_waiter().wait.assert_called_once() @@ -550,9 +550,10 @@ def test_job_run_app_started_no_wait_for_completion(self, mock_conn, mock_get_wa configuration_overrides=configuration_overrides, wait_for_completion=False, ) - default_name = operator.name id = operator.execute(None) assert id == job_run_id + default_name = operator.name + mock_conn.start_job_run.assert_called_once_with( clientToken=client_request_token, applicationId=application_id, @@ -581,11 +582,11 @@ def test_failed_start_job_run(self, mock_conn, mock_get_waiter): job_driver=job_driver, configuration_overrides=configuration_overrides, ) - default_name = operator.name with pytest.raises(AirflowException) as ex_message: operator.execute(None) - assert "EMR serverless job failed to start:" in str(ex_message.value) + default_name = operator.name + mock_conn.get_application.assert_called_once_with(applicationId=application_id) mock_get_waiter().wait.assert_called_once() mock_conn.start_job_run.assert_called_once_with( @@ -619,11 +620,11 @@ def test_start_job_run_fail_on_wait_for_completion(self, mock_conn, mock_get_wai job_driver=job_driver, configuration_overrides=configuration_overrides, ) - default_name = operator.name with pytest.raises(AirflowException) as ex_message: operator.execute(None) - assert "Serverless Job failed:" in str(ex_message.value) + default_name = operator.name + mock_conn.get_application.call_count == 2 mock_conn.start_job_run.assert_called_once_with( clientToken=client_request_token,