Skip to content

Commit

Permalink
Added name field to template_fields in EmrServerlessStartJobOperator (a…
Browse files Browse the repository at this point in the history
…pache#35648)

* Added name field to template_fields in EmrServerlessStartJobOperator

* Moved default name setting operation from constructor to execute method

* Update EmrServerlessStartJobOperator test class
  • Loading branch information
sseung00921 authored Nov 17, 2023
1 parent 46c0f85 commit 03a0b72
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 11 deletions.
4 changes: 3 additions & 1 deletion airflow/providers/amazon/aws/operators/emr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1173,6 +1173,7 @@ class EmrServerlessStartJobOperator(BaseOperator):
"execution_role_arn",
"job_driver",
"configuration_overrides",
"name",
)

template_fields_renderers = {
Expand Down Expand Up @@ -1226,7 +1227,7 @@ def __init__(
self.configuration_overrides = configuration_overrides
self.wait_for_completion = wait_for_completion
self.config = config or {}
self.name = name or self.config.pop("name", f"emr_serverless_job_airflow_{uuid4()}")
self.name = name
self.waiter_max_attempts = int(waiter_max_attempts) # type: ignore[arg-type]
self.waiter_delay = int(waiter_delay) # type: ignore[arg-type]
self.job_id: str | None = None
Expand Down Expand Up @@ -1268,6 +1269,7 @@ def execute(self, context: Context, event: dict[str, Any] | None = None) -> str
status_args=["application.state", "application.stateDetails"],
)
self.log.info("Starting job on Application: %s", self.application_id)
self.name = self.name or self.config.pop("name", f"emr_serverless_job_airflow_{uuid4()}")
response = self.hook.conn.start_job_run(
clientToken=self.client_request_token,
applicationId=self.application_id,
Expand Down
21 changes: 11 additions & 10 deletions tests/providers/amazon/aws/operators/test_emr_serverless.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,8 +375,8 @@ def test_job_run_app_started(self, mock_conn, mock_get_waiter):
job_driver=job_driver,
configuration_overrides=configuration_overrides,
)
default_name = operator.name
id = operator.execute(None)
default_name = operator.name

assert operator.wait_for_completion is True
mock_conn.get_application.assert_called_once_with(applicationId=application_id)
Expand Down Expand Up @@ -413,11 +413,12 @@ def test_job_run_job_failed(self, mock_conn, mock_get_waiter):
job_driver=job_driver,
configuration_overrides=configuration_overrides,
)
default_name = operator.name
with pytest.raises(AirflowException) as ex_message:
id = operator.execute(None)
assert id == job_run_id
assert "Serverless Job failed:" in str(ex_message.value)
default_name = operator.name

mock_conn.get_application.assert_called_once_with(applicationId=application_id)
mock_conn.start_job_run.assert_called_once_with(
clientToken=client_request_token,
Expand Down Expand Up @@ -446,9 +447,8 @@ def test_job_run_app_not_started(self, mock_conn, mock_get_waiter):
job_driver=job_driver,
configuration_overrides=configuration_overrides,
)
default_name = operator.name

id = operator.execute(None)
default_name = operator.name

assert operator.wait_for_completion is True
mock_conn.get_application.assert_called_once_with(applicationId=application_id)
Expand Down Expand Up @@ -516,8 +516,8 @@ def test_job_run_app_not_started_no_wait_for_completion(self, mock_conn, mock_ge
configuration_overrides=configuration_overrides,
wait_for_completion=False,
)
default_name = operator.name
id = operator.execute(None)
default_name = operator.name

mock_conn.get_application.assert_called_once_with(applicationId=application_id)
mock_get_waiter().wait.assert_called_once()
Expand Down Expand Up @@ -550,9 +550,10 @@ def test_job_run_app_started_no_wait_for_completion(self, mock_conn, mock_get_wa
configuration_overrides=configuration_overrides,
wait_for_completion=False,
)
default_name = operator.name
id = operator.execute(None)
assert id == job_run_id
default_name = operator.name

mock_conn.start_job_run.assert_called_once_with(
clientToken=client_request_token,
applicationId=application_id,
Expand Down Expand Up @@ -581,11 +582,11 @@ def test_failed_start_job_run(self, mock_conn, mock_get_waiter):
job_driver=job_driver,
configuration_overrides=configuration_overrides,
)
default_name = operator.name
with pytest.raises(AirflowException) as ex_message:
operator.execute(None)

assert "EMR serverless job failed to start:" in str(ex_message.value)
default_name = operator.name

mock_conn.get_application.assert_called_once_with(applicationId=application_id)
mock_get_waiter().wait.assert_called_once()
mock_conn.start_job_run.assert_called_once_with(
Expand Down Expand Up @@ -619,11 +620,11 @@ def test_start_job_run_fail_on_wait_for_completion(self, mock_conn, mock_get_wai
job_driver=job_driver,
configuration_overrides=configuration_overrides,
)
default_name = operator.name
with pytest.raises(AirflowException) as ex_message:
operator.execute(None)

assert "Serverless Job failed:" in str(ex_message.value)
default_name = operator.name

mock_conn.get_application.call_count == 2
mock_conn.start_job_run.assert_called_once_with(
clientToken=client_request_token,
Expand Down

0 comments on commit 03a0b72

Please sign in to comment.