Skip to content

Commit

Permalink
Implement on_status parameter for KubernetesDeleteJobOperator
Browse files Browse the repository at this point in the history
  • Loading branch information
moiseenkov committed Mar 25, 2024
1 parent 9c4e333 commit 38e7e18
Show file tree
Hide file tree
Showing 5 changed files with 241 additions and 25 deletions.
43 changes: 28 additions & 15 deletions airflow/providers/cncf/kubernetes/hooks/kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,18 +568,19 @@ def is_job_complete(self, job: V1Job) -> bool:
:return: Boolean indicating that the given job is complete.
"""
if conditions := job.status.conditions:
if final_condition_types := list(
c for c in conditions if c.type in JOB_FINAL_STATUS_CONDITION_TYPES and c.status
):
s = "s" if len(final_condition_types) > 1 else ""
self.log.info(
"The job '%s' state%s: %s",
job.metadata.name,
s,
", ".join(f"{c.type} at {c.last_transition_time}" for c in final_condition_types),
)
return True
if status := job.status:
if conditions := status.conditions:
if final_condition_types := list(
c for c in conditions if c.type in JOB_FINAL_STATUS_CONDITION_TYPES and c.status
):
s = "s" if len(final_condition_types) > 1 else ""
self.log.info(
"The job '%s' state%s: %s",
job.metadata.name,
s,
", ".join(f"{c.type} at {c.last_transition_time}" for c in final_condition_types),
)
return True
return False

@staticmethod
Expand All @@ -588,9 +589,21 @@ def is_job_failed(job: V1Job) -> str | bool:
:return: Error message if the job is failed, and False otherwise.
"""
conditions = job.status.conditions or []
if fail_condition := next((c for c in conditions if c.type == "Failed" and c.status), None):
return fail_condition.reason
if status := job.status:
conditions = status.conditions or []
if fail_condition := next((c for c in conditions if c.type == "Failed" and c.status), None):
return fail_condition.reason
return False

@staticmethod
def is_job_successful(job: V1Job) -> str | bool:
"""Check whether the given job is completed successfully..
:return: Error message if the job is failed, and False otherwise.
"""
if status := job.status:
conditions = status.conditions or []
return bool(next((c for c in conditions if c.type == "Complete" and c.status), None))
return False

def patch_namespaced_job(self, job_name: str, namespace: str, body: object) -> V1Job:
Expand Down
45 changes: 42 additions & 3 deletions airflow/providers/cncf/kubernetes/operators/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,10 +366,17 @@ class KubernetesDeleteJobOperator(BaseOperator):
:param in_cluster: run kubernetes client with in_cluster configuration.
:param cluster_context: context that points to kubernetes cluster.
Ignored when in_cluster is True. If None, current-context is used. (templated)
:param on_status: Condition for performing delete operation depending on the job status. Values:
``None`` - delete the job regardless of its status, "Complete" - delete only successfully completed
jobs, "Failed" - delete only failed jobs. (default: ``None``)
:param wait_until_job_complete: Whether to wait for the job to complete. (default: ``False``)
:param job_poll_interval: Interval in seconds between polling the job status. Used when the
`wait_until_job_complete` parameter is set to ``True``. (default: 10.0)
"""

template_fields: Sequence[str] = (
"config_file",
"name",
"namespace",
"cluster_context",
)
Expand All @@ -383,6 +390,9 @@ def __init__(
config_file: str | None = None,
in_cluster: bool | None = None,
cluster_context: str | None = None,
on_status: str | None = None,
wait_until_job_complete: bool = False,
job_poll_interval: float = 10.0,
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand All @@ -392,6 +402,9 @@ def __init__(
self.config_file = config_file
self.in_cluster = in_cluster
self.cluster_context = cluster_context
self.on_status = on_status
self.wait_until_job_complete = wait_until_job_complete
self.job_poll_interval = job_poll_interval

@cached_property
def hook(self) -> KubernetesHook:
Expand All @@ -408,9 +421,34 @@ def client(self) -> BatchV1Api:

def execute(self, context: Context):
try:
self.log.info("Deleting kubernetes Job: %s", self.name)
self.client.delete_namespaced_job(name=self.name, namespace=self.namespace)
self.log.info("Kubernetes job was deleted.")
if self.on_status not in ("Complete", "Failed", None):
raise AirflowException(
"The `on_status` parameter must be one of 'Complete', 'Failed' or None. "
"The current value is %s",
str(self.on_status),
)

if self.wait_until_job_complete:
job = self.hook.wait_until_job_complete(
job_name=self.name, namespace=self.namespace, job_poll_interval=self.job_poll_interval
)
else:
job = self.hook.get_job_status(job_name=self.name, namespace=self.namespace)

if (
self.on_status is None
or (self.on_status == "Complete" and self.hook.is_job_successful(job=job))
or (self.on_status == "Failed" and self.hook.is_job_failed(job=job))
):
self.log.info("Deleting kubernetes Job: %s", self.name)
self.client.delete_namespaced_job(name=self.name, namespace=self.namespace)
self.log.info("Kubernetes job was deleted.")
else:
self.log.info(
"Deletion of the job %s was skipped due to settings of on_status=%s",
self.name,
self.on_status,
)
except ApiException as e:
if e.status == 404:
self.log.info("The Kubernetes job %s does not exist.", self.name)
Expand Down Expand Up @@ -442,6 +480,7 @@ class KubernetesPatchJobOperator(BaseOperator):

template_fields: Sequence[str] = (
"config_file",
"name",
"namespace",
"body",
"cluster_context",
Expand Down
57 changes: 57 additions & 0 deletions tests/providers/cncf/kubernetes/hooks/test_kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,52 @@ def test_is_job_failed(self, mock_merger, mock_loader, conditions, expected_resu

assert actual_result == expected_result

@patch("kubernetes.config.kube_config.KubeConfigLoader")
@patch("kubernetes.config.kube_config.KubeConfigMerger")
def test_is_job_failed_no_status(self, mock_merger, mock_loader):
    """A job that has no ``status`` at all must not be reported as failed."""
    job_without_status = mock.MagicMock()
    job_without_status.status = None

    result = KubernetesHook().is_job_failed(job_without_status)

    assert not result

@pytest.mark.parametrize(
    "condition_type, status, expected_result",
    [
        ("Complete", False, False),
        ("Complete", True, True),
        ("Failed", False, False),
        ("Failed", True, False),
        ("Suspended", False, False),
        ("Suspended", True, False),
        ("Unknown", False, False),
        ("Unknown", True, False),
    ],
)
@patch("kubernetes.config.kube_config.KubeConfigLoader")
@patch("kubernetes.config.kube_config.KubeConfigMerger")
def test_is_job_successful(self, mock_merger, mock_loader, condition_type, status, expected_result):
    """Only a ``Complete`` condition with a true status counts as success."""
    # Build a job carrying exactly one condition of the parametrized type/status.
    job = mock.MagicMock()
    single_condition = mock.MagicMock(type=condition_type, status=status)
    job.status.conditions = [single_condition]

    assert KubernetesHook().is_job_successful(job) == expected_result

@patch("kubernetes.config.kube_config.KubeConfigLoader")
@patch("kubernetes.config.kube_config.KubeConfigMerger")
def test_is_job_successful_no_status(self, mock_merger, mock_loader):
    """A job that has no ``status`` at all must not be reported as successful."""
    job_without_status = mock.MagicMock()
    job_without_status.status = None

    result = KubernetesHook().is_job_successful(job_without_status)

    assert not result

@pytest.mark.parametrize(
"condition_type, status, expected_result",
[
Expand All @@ -521,6 +567,17 @@ def test_is_job_complete(self, mock_merger, mock_loader, condition_type, status,

assert actual_result == expected_result

@patch("kubernetes.config.kube_config.KubeConfigLoader")
@patch("kubernetes.config.kube_config.KubeConfigMerger")
def test_is_job_complete_no_status(self, mock_merger, mock_loader):
    """A job that has no ``status`` at all must not be reported as complete."""
    job_without_status = mock.MagicMock()
    job_without_status.status = None

    result = KubernetesHook().is_job_complete(job_without_status)

    assert not result

@patch("kubernetes.config.kube_config.KubeConfigLoader")
@patch("kubernetes.config.kube_config.KubeConfigMerger")
@patch(f"{HOOK_MODULE}.KubernetesHook.get_job_status")
Expand Down
112 changes: 108 additions & 4 deletions tests/providers/cncf/kubernetes/operators/test_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
# under the License.
from __future__ import annotations

import random
import re
import string
from unittest import mock
from unittest.mock import patch

Expand All @@ -41,6 +43,7 @@
POLL_INTERVAL = 100
JOB_NAME = "test-job"
JOB_NAMESPACE = "test-namespace"
JOB_POLL_INTERVAL = 20.0
KUBERNETES_CONN_ID = "test-conn_id"


Expand Down Expand Up @@ -694,19 +697,120 @@ def setup_tests(self):

patch.stopall()

@patch(f"{HOOK_CLASS}.get_job_status")
@patch(f"{HOOK_CLASS}.wait_until_job_complete")
@patch("kubernetes.config.load_kube_config")
@patch("kubernetes.client.api.BatchV1Api.delete_namespaced_job")
def test_delete_execute(self, mock_delete_namespaced_job, mock_load_kube_config):
def test_execute(
self,
mock_delete_namespaced_job,
mock_load_kube_config,
mock_wait_until_job_complete,
mock_get_job_status,
):
op = KubernetesDeleteJobOperator(
kubernetes_conn_id="kubernetes_default",
task_id="test_delete_job",
name="test_job_name",
namespace="test_job_namespace",
name=JOB_NAME,
namespace=JOB_NAMESPACE,
)

op.execute(None)

mock_delete_namespaced_job.assert_called()
assert not mock_wait_until_job_complete.called
mock_get_job_status.assert_called_once_with(job_name=JOB_NAME, namespace=JOB_NAMESPACE)
mock_delete_namespaced_job.assert_called_once_with(name=JOB_NAME, namespace=JOB_NAMESPACE)

@patch(f"{HOOK_CLASS}.get_job_status")
@patch(f"{HOOK_CLASS}.wait_until_job_complete")
@patch("kubernetes.config.load_kube_config")
@patch("kubernetes.client.api.BatchV1Api.delete_namespaced_job")
def test_execute_wait_until_job_complete_true(
    self,
    mock_delete_namespaced_job,
    mock_load_kube_config,
    mock_wait_until_job_complete,
    mock_get_job_status,
):
    """With ``wait_until_job_complete=True`` the operator waits for the job, then deletes it."""
    delete_operator = KubernetesDeleteJobOperator(
        kubernetes_conn_id="kubernetes_default",
        task_id="test_delete_job",
        name=JOB_NAME,
        namespace=JOB_NAMESPACE,
        wait_until_job_complete=True,
        job_poll_interval=JOB_POLL_INTERVAL,
    )

    delete_operator.execute({})

    # The waiting path replaces the plain status lookup entirely.
    mock_get_job_status.assert_not_called()
    mock_wait_until_job_complete.assert_called_once_with(
        job_name=JOB_NAME, namespace=JOB_NAMESPACE, job_poll_interval=JOB_POLL_INTERVAL
    )
    mock_delete_namespaced_job.assert_called_once_with(name=JOB_NAME, namespace=JOB_NAMESPACE)

@pytest.mark.parametrize(
    "on_status, success, fail, deleted",
    [
        (None, True, True, True),
        (None, True, False, True),
        (None, False, True, True),
        (None, False, False, True),
        ("Complete", True, True, True),
        ("Complete", True, False, True),
        ("Complete", False, True, False),
        ("Complete", False, False, False),
        ("Failed", True, True, True),
        ("Failed", True, False, False),
        ("Failed", False, True, True),
        ("Failed", False, False, False),
    ],
)
@patch(f"{HOOK_CLASS}.is_job_failed")
@patch(f"{HOOK_CLASS}.is_job_successful")
@patch("kubernetes.config.load_kube_config")
@patch("kubernetes.client.api.BatchV1Api.delete_namespaced_job")
def test_execute_on_status(
    self,
    mock_delete_namespaced_job,
    mock_load_kube_config,
    mock_is_job_successful,
    mock_is_job_failed,
    on_status,
    success,
    fail,
    deleted,
):
    """The job is deleted only when its mocked state matches the requested ``on_status``."""
    # Drive the hook's status checks from the parametrized flags.
    mock_is_job_failed.return_value = fail
    mock_is_job_successful.return_value = success

    delete_operator = KubernetesDeleteJobOperator(
        kubernetes_conn_id="kubernetes_default",
        task_id="test_delete_job",
        name=JOB_NAME,
        namespace=JOB_NAMESPACE,
        on_status=on_status,
    )
    delete_operator.execute({})

    was_deleted = mock_delete_namespaced_job.called
    assert was_deleted == deleted

def test_execute_on_status_exception(self):
    """An unsupported ``on_status`` value must make ``execute`` raise ``AirflowException``."""
    # Use a fixed, clearly invalid value instead of a randomly generated string: the test
    # stays reproducible and can never accidentally produce a supported value such as
    # "Complete" or "Failed".
    invalid_on_status = "NotASupportedStatus"

    op = KubernetesDeleteJobOperator(
        kubernetes_conn_id="kubernetes_default",
        task_id="test_delete_job",
        name=JOB_NAME,
        namespace=JOB_NAMESPACE,
        on_status=invalid_on_status,
    )

    with pytest.raises(AirflowException):
        op.execute({})


@pytest.mark.execution_timeout(300)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
update_job = KubernetesPatchJobOperator(
task_id="update-job-task",
namespace="default",
name=JOB_NAME,
name=k8s_job.output["job_name"],
body={"spec": {"suspend": False}},
)
# [END howto_operator_update_job]
Expand All @@ -77,14 +77,17 @@
# [START howto_operator_delete_k8s_job]
delete_job_task = KubernetesDeleteJobOperator(
task_id="delete_job_task",
name=JOB_NAME,
name=k8s_job.output["job_name"],
namespace=JOB_NAMESPACE,
wait_until_job_complete=True,
on_status="Complete",
job_poll_interval=1.0,
)
# [END howto_operator_delete_k8s_job]

delete_job_task_def = KubernetesDeleteJobOperator(
task_id="delete_job_task_def",
name=JOB_NAME + "-def",
name=k8s_job_def.output["job_name"],
namespace=JOB_NAMESPACE,
)

Expand Down

0 comments on commit 38e7e18

Please sign in to comment.