Skip to content

Commit

Permalink
Added retry strategy parameter to Amazon AWS provider Batch Operator …
Browse files Browse the repository at this point in the history
…to allow dynamic Batch retry strategies (apache#35789)
  • Loading branch information
evgenyslab authored Nov 22, 2023
1 parent 72ba63e commit b71c14c
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 0 deletions.
5 changes: 5 additions & 0 deletions airflow/providers/amazon/aws/operators/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ class BatchOperator(BaseOperator):
"array_properties",
"node_overrides",
"parameters",
"retry_strategy",
"waiters",
"tags",
"wait_for_completion",
Expand All @@ -122,6 +123,7 @@ class BatchOperator(BaseOperator):
"container_overrides": "json",
"parameters": "json",
"node_overrides": "json",
"retry_strategy": "json",
}

@property
Expand Down Expand Up @@ -160,6 +162,7 @@ def __init__(
share_identifier: str | None = None,
scheduling_priority_override: int | None = None,
parameters: dict | None = None,
retry_strategy: dict | None = None,
job_id: str | None = None,
waiters: Any | None = None,
max_retries: int = 4200,
Expand Down Expand Up @@ -201,6 +204,7 @@ def __init__(
self.scheduling_priority_override = scheduling_priority_override
self.array_properties = array_properties
self.parameters = parameters or {}
self.retry_strategy = retry_strategy or {}
self.waiters = waiters
self.tags = tags or {}
self.wait_for_completion = wait_for_completion
Expand Down Expand Up @@ -287,6 +291,7 @@ def submit_job(self, context: Context):
"tags": self.tags,
"containerOverrides": self.container_overrides,
"nodeOverrides": self.node_overrides,
"retryStrategy": self.retry_strategy,
"shareIdentifier": self.share_identifier,
"schedulingPriorityOverride": self.scheduling_priority_override,
}
Expand Down
6 changes: 6 additions & 0 deletions tests/providers/amazon/aws/operators/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def setup_method(self, _, get_client_type_mock):
max_retries=self.MAX_RETRIES,
status_retries=self.STATUS_RETRIES,
parameters=None,
retry_strategy=None,
container_overrides={},
array_properties=None,
aws_conn_id="airflow_test",
Expand Down Expand Up @@ -96,6 +97,7 @@ def test_init(self):
assert self.batch.hook.max_retries == self.MAX_RETRIES
assert self.batch.hook.status_retries == self.STATUS_RETRIES
assert self.batch.parameters == {}
assert self.batch.retry_strategy == {}
assert self.batch.container_overrides == {}
assert self.batch.array_properties is None
assert self.batch.node_overrides is None
Expand All @@ -119,6 +121,7 @@ def test_template_fields_overrides(self):
"array_properties",
"node_overrides",
"parameters",
"retry_strategy",
"waiters",
"tags",
"wait_for_completion",
Expand All @@ -143,6 +146,7 @@ def test_execute_without_failures(self, check_mock, wait_mock, job_description_m
containerOverrides={},
jobDefinition="hello-world",
parameters={},
retryStrategy={},
tags={},
)

Expand All @@ -166,6 +170,7 @@ def test_execute_with_failures(self):
containerOverrides={},
jobDefinition="hello-world",
parameters={},
retryStrategy={},
tags={},
)

Expand Down Expand Up @@ -232,6 +237,7 @@ def test_override_not_sent_if_not_set(self, client_mock, override):
"jobName": JOB_NAME,
"jobDefinition": "hello-world",
"parameters": {},
"retryStrategy": {},
"tags": {},
}
if override == "overrides":
Expand Down

0 comments on commit b71c14c

Please sign in to comment.