
Handle task adoption for batch executor (apache#39590)
vincbeck authored May 14, 2024
1 parent 1489cf7 commit 339ea50
Showing 2 changed files with 76 additions and 3 deletions.
50 changes: 47 additions & 3 deletions airflow/providers/amazon/aws/executors/batch/batch_executor.py
@@ -19,10 +19,11 @@
 
 from __future__ import annotations
 
+import contextlib
 import time
 from collections import defaultdict, deque
 from copy import deepcopy
-from typing import TYPE_CHECKING, Any, Dict, List
+from typing import TYPE_CHECKING, Any, Dict, List, Sequence
 
 from botocore.exceptions import ClientError, NoCredentialsError
 
@@ -34,11 +35,12 @@
     exponential_backoff_retry,
 )
 from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
+from airflow.stats import Stats
 from airflow.utils import timezone
 from airflow.utils.helpers import merge_dicts
 
 if TYPE_CHECKING:
-    from airflow.models.taskinstance import TaskInstanceKey
+    from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
     from airflow.providers.amazon.aws.executors.batch.boto_schema import (
         BatchDescribeJobsResponseSchema,
         BatchSubmitJobResponseSchema,
@@ -306,14 +308,20 @@ def attempt_submit_jobs(self):
                 self.pending_jobs.append(batch_job)
             else:
                 # Success case
+                job_id = submit_job_response["job_id"]
                 self.active_workers.add_job(
-                    job_id=submit_job_response["job_id"],
+                    job_id=job_id,
                     airflow_task_key=key,
                     airflow_cmd=cmd,
                     queue=queue,
                     exec_config=exec_config,
                     attempt_number=attempt_number,
                 )
+                with contextlib.suppress(AttributeError):
+                    # TODO: Remove this when min_airflow_version is 2.10.0 or higher in Amazon provider.
+                    # running_state is added in Airflow 2.10 and only needed to support task adoption
+                    # (an optional executor feature).
+                    self.running_state(key, job_id)
         if failure_reasons:
             self.log.error(
                 "Pending Batch jobs failed to launch for the following reasons: %s. Retrying later.",
Expand Down Expand Up @@ -418,3 +426,39 @@ def _load_submit_kwargs() -> dict:
" and value should be NULL or empty."
)
return submit_kwargs

def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[TaskInstance]:
"""
Adopt task instances which have an external_executor_id (the Batch job ID).
Anything that is not adopted will be cleared by the scheduler and becomes eligible for re-scheduling.
"""
with Stats.timer("batch_executor.adopt_task_instances.duration"):
adopted_tis: list[TaskInstance] = []

if job_ids := [ti.external_executor_id for ti in tis if ti.external_executor_id]:
batch_jobs = self._describe_jobs(job_ids)

for batch_job in batch_jobs:
ti = next(ti for ti in tis if ti.external_executor_id == batch_job.job_id)
self.active_workers.add_job(
job_id=batch_job.job_id,
airflow_task_key=ti.key,
airflow_cmd=ti.command_as_list(),
queue=ti.queue,
exec_config=ti.executor_config,
attempt_number=ti.prev_attempted_tries,
)
adopted_tis.append(ti)

if adopted_tis:
tasks = [f"{task} in state {task.state}" for task in adopted_tis]
task_instance_str = "\n\t".join(tasks)
self.log.info(
"Adopted the following %d tasks from a dead executor:\n\t%s",
len(adopted_tis),
task_instance_str,
)

not_adopted_tis = [ti for ti in tis if ti not in adopted_tis]
return not_adopted_tis
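Taken together, the two changes above implement the executor adoption contract: external_executor_id is populated from the Batch job ID reported at submission time, and on a scheduler restart try_adopt_task_instances re-registers any orphaned task instance whose ID matches a job that Batch still reports, returning the rest so the scheduler can reset them. The following is a minimal, self-contained sketch of that contract; FakeBatchExecutor and FakeTaskInstance are illustrative stand-ins, not the provider's real classes.

from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class FakeTaskInstance:
    key: str
    external_executor_id: str | None = None


@dataclass
class FakeBatchExecutor:
    live_job_ids: set[str]
    # job_id -> task key, mirroring the active_workers bookkeeping
    active_workers: dict[str, str] = field(default_factory=dict)

    def try_adopt_task_instances(self, tis):
        adopted = []
        for ti in tis:
            # Only task instances whose recorded job ID matches a live Batch job are adopted.
            if ti.external_executor_id in self.live_job_ids:
                self.active_workers[ti.external_executor_id] = ti.key
                adopted.append(ti)
        # Everything returned is treated as "not adopted" and gets reset by the scheduler.
        return [ti for ti in tis if ti not in adopted]


executor = FakeBatchExecutor(live_job_ids={"001"})
tis = [FakeTaskInstance("t1", "001"), FakeTaskInstance("t2", None)]
print([ti.key for ti in executor.try_adopt_task_instances(tis)])  # prints ['t2']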
29 changes: 29 additions & 0 deletions tests/providers/amazon/aws/executors/batch/test_batch_executor.py
@@ -28,6 +28,7 @@
 
 from airflow.exceptions import AirflowException
 from airflow.executors.base_executor import BaseExecutor
+from airflow.models import TaskInstance
 from airflow.providers.amazon.aws.executors.batch import batch_executor, batch_executor_config
 from airflow.providers.amazon.aws.executors.batch.batch_executor import (
     AwsBatchExecutor,
@@ -615,6 +616,34 @@ def _mock_sync(
         }
         executor.batch.describe_jobs.return_value = {"jobs": [after_batch_job]}
 
+    def test_try_adopt_task_instances(self, mock_executor):
+        """Test that executor can adopt orphaned task instances from a SchedulerJob shutdown event."""
+        mock_executor.batch.describe_jobs.return_value = {
+            "jobs": [
+                {"jobId": "001", "status": "SUCCEEDED"},
+                {"jobId": "002", "status": "SUCCEEDED"},
+            ],
+        }
+
+        orphaned_tasks = [
+            mock.Mock(spec=TaskInstance),
+            mock.Mock(spec=TaskInstance),
+            mock.Mock(spec=TaskInstance),
+        ]
+        orphaned_tasks[0].external_executor_id = "001"  # Matches a running Batch job ID
+        orphaned_tasks[1].external_executor_id = "002"  # Matches a running Batch job ID
+        orphaned_tasks[2].external_executor_id = None  # One orphaned task has no external_executor_id
+        for task in orphaned_tasks:
+            task.try_number = 1
+
+        not_adopted_tasks = mock_executor.try_adopt_task_instances(orphaned_tasks)
+
+        mock_executor.batch.describe_jobs.assert_called_once()
+        # Two of the three tasks should be adopted.
+        assert len(orphaned_tasks) - 1 == len(mock_executor.active_workers)
+        # The remaining task is unable to be adopted.
+        assert 1 == len(not_adopted_tasks)
+
 
 class TestBatchExecutorConfig:
     @staticmethod
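For reference, the new test can be run on its own from the repository root. This is a hypothetical local invocation that selects it by keyword; the repository's pytest configuration may require additional options.

import pytest

# Select only the new adoption test by keyword; the path is the one shown in the diff above.
pytest.main(
    [
        "tests/providers/amazon/aws/executors/batch/test_batch_executor.py",
        "-k", "test_try_adopt_task_instances",
        "-q",
    ]
)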
