diff --git a/airflow/providers/cncf/kubernetes/hooks/kubernetes.py b/airflow/providers/cncf/kubernetes/hooks/kubernetes.py index ec67254e3c989..1b3c4254e8e0e 100644 --- a/airflow/providers/cncf/kubernetes/hooks/kubernetes.py +++ b/airflow/providers/cncf/kubernetes/hooks/kubernetes.py @@ -568,18 +568,19 @@ def is_job_complete(self, job: V1Job) -> bool: :return: Boolean indicating that the given job is complete. """ - if conditions := job.status.conditions: - if final_condition_types := list( - c for c in conditions if c.type in JOB_FINAL_STATUS_CONDITION_TYPES and c.status - ): - s = "s" if len(final_condition_types) > 1 else "" - self.log.info( - "The job '%s' state%s: %s", - job.metadata.name, - s, - ", ".join(f"{c.type} at {c.last_transition_time}" for c in final_condition_types), - ) - return True + if status := job.status: + if conditions := status.conditions: + if final_condition_types := list( + c for c in conditions if c.type in JOB_FINAL_STATUS_CONDITION_TYPES and c.status + ): + s = "s" if len(final_condition_types) > 1 else "" + self.log.info( + "The job '%s' state%s: %s", + job.metadata.name, + s, + ", ".join(f"{c.type} at {c.last_transition_time}" for c in final_condition_types), + ) + return True return False @staticmethod @@ -588,9 +589,21 @@ def is_job_failed(job: V1Job) -> str | bool: :return: Error message if the job is failed, and False otherwise. """ - conditions = job.status.conditions or [] - if fail_condition := next((c for c in conditions if c.type == "Failed" and c.status), None): - return fail_condition.reason + if status := job.status: + conditions = status.conditions or [] + if fail_condition := next((c for c in conditions if c.type == "Failed" and c.status), None): + return fail_condition.reason + return False + + @staticmethod + def is_job_successful(job: V1Job) -> str | bool: + """Check whether the given job is completed successfully. + + :return: Boolean indicating that the given job is completed successfully. 
+ """ + if status := job.status: + conditions = status.conditions or [] + return bool(next((c for c in conditions if c.type == "Complete" and c.status), None)) return False def patch_namespaced_job(self, job_name: str, namespace: str, body: object) -> V1Job: diff --git a/airflow/providers/cncf/kubernetes/operators/job.py b/airflow/providers/cncf/kubernetes/operators/job.py index 41d260bc98483..95a912344d18d 100644 --- a/airflow/providers/cncf/kubernetes/operators/job.py +++ b/airflow/providers/cncf/kubernetes/operators/job.py @@ -366,10 +366,17 @@ class KubernetesDeleteJobOperator(BaseOperator): :param in_cluster: run kubernetes client with in_cluster configuration. :param cluster_context: context that points to kubernetes cluster. Ignored when in_cluster is True. If None, current-context is used. (templated) + :param on_status: Condition for performing delete operation depending on the job status. Values: + ``None`` - delete the job regardless of its status, "Complete" - delete only successfully completed + jobs, "Failed" - delete only failed jobs. (default: ``None``) + :param wait_until_job_complete: Whether to wait for the job to complete. (default: ``False``) + :param job_poll_interval: Interval in seconds between polling the job status. Used when + `wait_until_job_complete` is enabled. 
(default: 10.0) """ template_fields: Sequence[str] = ( "config_file", + "name", "namespace", "cluster_context", ) @@ -383,6 +390,9 @@ def __init__( config_file: str | None = None, in_cluster: bool | None = None, cluster_context: str | None = None, + on_status: str | None = None, + wait_until_job_complete: bool = False, + job_poll_interval: float = 10.0, **kwargs, ) -> None: super().__init__(**kwargs) @@ -392,6 +402,9 @@ def __init__( self.config_file = config_file self.in_cluster = in_cluster self.cluster_context = cluster_context + self.on_status = on_status + self.wait_until_job_complete = wait_until_job_complete + self.job_poll_interval = job_poll_interval @cached_property def hook(self) -> KubernetesHook: @@ -408,9 +421,34 @@ def client(self) -> BatchV1Api: def execute(self, context: Context): try: - self.log.info("Deleting kubernetes Job: %s", self.name) - self.client.delete_namespaced_job(name=self.name, namespace=self.namespace) - self.log.info("Kubernetes job was deleted.") + if self.on_status not in ("Complete", "Failed", None): + raise AirflowException( + "The `on_status` parameter must be one of 'Complete', 'Failed' or None. 
" + "The current value is %s", + str(self.on_status), + ) + + if self.wait_until_job_complete: + job = self.hook.wait_until_job_complete( + job_name=self.name, namespace=self.namespace, job_poll_interval=self.job_poll_interval + ) + else: + job = self.hook.get_job_status(job_name=self.name, namespace=self.namespace) + + if ( + self.on_status is None + or (self.on_status == "Complete" and self.hook.is_job_successful(job=job)) + or (self.on_status == "Failed" and self.hook.is_job_failed(job=job)) + ): + self.log.info("Deleting kubernetes Job: %s", self.name) + self.client.delete_namespaced_job(name=self.name, namespace=self.namespace) + self.log.info("Kubernetes job was deleted.") + else: + self.log.info( + "Deletion of the job %s was skipped due to settings of on_status=%s", + self.name, + self.on_status, + ) except ApiException as e: if e.status == 404: self.log.info("The Kubernetes job %s does not exist.", self.name) @@ -442,6 +480,7 @@ class KubernetesPatchJobOperator(BaseOperator): template_fields: Sequence[str] = ( "config_file", + "name", "namespace", "body", "cluster_context", diff --git a/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py b/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py index 197aef4d822f5..5f6af754b11af 100644 --- a/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py +++ b/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py @@ -497,6 +497,52 @@ def test_is_job_failed(self, mock_merger, mock_loader, conditions, expected_resu assert actual_result == expected_result + @patch("kubernetes.config.kube_config.KubeConfigLoader") + @patch("kubernetes.config.kube_config.KubeConfigMerger") + def test_is_job_failed_no_status(self, mock_merger, mock_loader): + mock_job = mock.MagicMock() + mock_job.status = None + + hook = KubernetesHook() + job_failed = hook.is_job_failed(mock_job) + + assert not job_failed + + @pytest.mark.parametrize( + "condition_type, status, expected_result", + [ + ("Complete", False, False), + 
("Complete", True, True), + ("Failed", False, False), + ("Failed", True, False), + ("Suspended", False, False), + ("Suspended", True, False), + ("Unknown", False, False), + ("Unknown", True, False), + ], + ) + @patch("kubernetes.config.kube_config.KubeConfigLoader") + @patch("kubernetes.config.kube_config.KubeConfigMerger") + def test_is_job_successful(self, mock_merger, mock_loader, condition_type, status, expected_result): + mock_job = mock.MagicMock() + mock_job.status.conditions = [mock.MagicMock(type=condition_type, status=status)] + + hook = KubernetesHook() + actual_result = hook.is_job_successful(mock_job) + + assert actual_result == expected_result + + @patch("kubernetes.config.kube_config.KubeConfigLoader") + @patch("kubernetes.config.kube_config.KubeConfigMerger") + def test_is_job_successful_no_status(self, mock_merger, mock_loader): + mock_job = mock.MagicMock() + mock_job.status = None + + hook = KubernetesHook() + job_successful = hook.is_job_successful(mock_job) + + assert not job_successful + @pytest.mark.parametrize( "condition_type, status, expected_result", [ @@ -521,6 +567,17 @@ def test_is_job_complete(self, mock_merger, mock_loader, condition_type, status, assert actual_result == expected_result + @patch("kubernetes.config.kube_config.KubeConfigLoader") + @patch("kubernetes.config.kube_config.KubeConfigMerger") + def test_is_job_complete_no_status(self, mock_merger, mock_loader): + mock_job = mock.MagicMock() + mock_job.status = None + + hook = KubernetesHook() + job_complete = hook.is_job_complete(mock_job) + + assert not job_complete + @patch("kubernetes.config.kube_config.KubeConfigLoader") @patch("kubernetes.config.kube_config.KubeConfigMerger") @patch(f"{HOOK_MODULE}.KubernetesHook.get_job_status") diff --git a/tests/providers/cncf/kubernetes/operators/test_job.py b/tests/providers/cncf/kubernetes/operators/test_job.py index 803523429ce5c..23a947670253a 100644 --- a/tests/providers/cncf/kubernetes/operators/test_job.py +++ 
b/tests/providers/cncf/kubernetes/operators/test_job.py @@ -16,7 +16,9 @@ # under the License. from __future__ import annotations +import random import re +import string from unittest import mock from unittest.mock import patch @@ -41,6 +43,7 @@ POLL_INTERVAL = 100 JOB_NAME = "test-job" JOB_NAMESPACE = "test-namespace" +JOB_POLL_INTERVAL = 20.0 KUBERNETES_CONN_ID = "test-conn_id" @@ -694,19 +697,120 @@ def setup_tests(self): patch.stopall() + @patch(f"{HOOK_CLASS}.get_job_status") + @patch(f"{HOOK_CLASS}.wait_until_job_complete") @patch("kubernetes.config.load_kube_config") @patch("kubernetes.client.api.BatchV1Api.delete_namespaced_job") - def test_delete_execute(self, mock_delete_namespaced_job, mock_load_kube_config): + def test_execute( + self, + mock_delete_namespaced_job, + mock_load_kube_config, + mock_wait_until_job_complete, + mock_get_job_status, + ): op = KubernetesDeleteJobOperator( kubernetes_conn_id="kubernetes_default", task_id="test_delete_job", - name="test_job_name", - namespace="test_job_namespace", + name=JOB_NAME, + namespace=JOB_NAMESPACE, ) op.execute(None) - mock_delete_namespaced_job.assert_called() + assert not mock_wait_until_job_complete.called + mock_get_job_status.assert_called_once_with(job_name=JOB_NAME, namespace=JOB_NAMESPACE) + mock_delete_namespaced_job.assert_called_once_with(name=JOB_NAME, namespace=JOB_NAMESPACE) + + @patch(f"{HOOK_CLASS}.get_job_status") + @patch(f"{HOOK_CLASS}.wait_until_job_complete") + @patch("kubernetes.config.load_kube_config") + @patch("kubernetes.client.api.BatchV1Api.delete_namespaced_job") + def test_execute_wait_until_job_complete_true( + self, + mock_delete_namespaced_job, + mock_load_kube_config, + mock_wait_until_job_complete, + mock_get_job_status, + ): + op = KubernetesDeleteJobOperator( + kubernetes_conn_id="kubernetes_default", + task_id="test_delete_job", + name=JOB_NAME, + namespace=JOB_NAMESPACE, + wait_until_job_complete=True, + job_poll_interval=JOB_POLL_INTERVAL, + ) + + op.execute({}) + 
+ mock_wait_until_job_complete.assert_called_once_with( + job_name=JOB_NAME, namespace=JOB_NAMESPACE, job_poll_interval=JOB_POLL_INTERVAL + ) + assert not mock_get_job_status.called + mock_delete_namespaced_job.assert_called_once_with(name=JOB_NAME, namespace=JOB_NAMESPACE) + + @pytest.mark.parametrize( + "on_status, success, fail, deleted", + [ + (None, True, True, True), + (None, True, False, True), + (None, False, True, True), + (None, False, False, True), + ("Complete", True, True, True), + ("Complete", True, False, True), + ("Complete", False, True, False), + ("Complete", False, False, False), + ("Failed", True, True, True), + ("Failed", True, False, False), + ("Failed", False, True, True), + ("Failed", False, False, False), + ], + ) + @patch(f"{HOOK_CLASS}.is_job_failed") + @patch(f"{HOOK_CLASS}.is_job_successful") + @patch("kubernetes.config.load_kube_config") + @patch("kubernetes.client.api.BatchV1Api.delete_namespaced_job") + def test_execute_on_status( + self, + mock_delete_namespaced_job, + mock_load_kube_config, + mock_is_job_successful, + mock_is_job_failed, + on_status, + success, + fail, + deleted, + ): + mock_is_job_successful.return_value = success + mock_is_job_failed.return_value = fail + + op = KubernetesDeleteJobOperator( + kubernetes_conn_id="kubernetes_default", + task_id="test_delete_job", + name=JOB_NAME, + namespace=JOB_NAMESPACE, + on_status=on_status, + ) + + op.execute({}) + + assert mock_delete_namespaced_job.called == deleted + + def test_execute_on_status_exception(self): + invalid_on_status = "".join( + random.choices(string.ascii_letters + string.digits, k=random.randint(1, 16)) + ) + + op = KubernetesDeleteJobOperator( + kubernetes_conn_id="kubernetes_default", + task_id="test_delete_job", + name=JOB_NAME, + namespace=JOB_NAMESPACE, + on_status=invalid_on_status, + ) + + with pytest.raises(AirflowException): + op.execute({}) @pytest.mark.execution_timeout(300) diff --git 
a/tests/system/providers/cncf/kubernetes/example_kubernetes_job.py b/tests/system/providers/cncf/kubernetes/example_kubernetes_job.py index 0f17f57a15414..a48763eae8c6f 100644 --- a/tests/system/providers/cncf/kubernetes/example_kubernetes_job.py +++ b/tests/system/providers/cncf/kubernetes/example_kubernetes_job.py @@ -57,7 +57,7 @@ update_job = KubernetesPatchJobOperator( task_id="update-job-task", namespace="default", - name=JOB_NAME, + name=k8s_job.output["job_name"], body={"spec": {"suspend": False}}, ) # [END howto_operator_update_job] @@ -77,14 +77,17 @@ # [START howto_operator_delete_k8s_job] delete_job_task = KubernetesDeleteJobOperator( task_id="delete_job_task", - name=JOB_NAME, + name=k8s_job.output["job_name"], namespace=JOB_NAMESPACE, + wait_until_job_complete=True, + on_status="Complete", + job_poll_interval=1.0, ) # [END howto_operator_delete_k8s_job] delete_job_task_def = KubernetesDeleteJobOperator( task_id="delete_job_task_def", - name=JOB_NAME + "-def", + name=k8s_job_def.output["job_name"], namespace=JOB_NAMESPACE, )