diff --git a/airflow/providers/cncf/kubernetes/hooks/kubernetes.py b/airflow/providers/cncf/kubernetes/hooks/kubernetes.py index ffa484a9db0ab..ec67254e3c989 100644 --- a/airflow/providers/cncf/kubernetes/hooks/kubernetes.py +++ b/airflow/providers/cncf/kubernetes/hooks/kubernetes.py @@ -16,6 +16,7 @@ # under the License. from __future__ import annotations +import asyncio import contextlib import json import tempfile @@ -582,13 +583,15 @@ def is_job_complete(self, job: V1Job) -> bool: return False @staticmethod - def is_job_failed(job: V1Job) -> bool: + def is_job_failed(job: V1Job) -> str | bool: """Check whether the given job is failed. - :return: Boolean indicating that the given job is failed. + :return: Error message if the job is failed, and False otherwise. """ conditions = job.status.conditions or [] - return bool(next((c for c in conditions if c.type == "Failed" and c.status), None)) + if fail_condition := next((c for c in conditions if c.type == "Failed" and c.status), None): + return fail_condition.reason + return False def patch_namespaced_job(self, job_name: str, namespace: str, body: object) -> V1Job: """ @@ -769,3 +772,34 @@ async def read_logs(self, name: str, namespace: str): except HTTPError: self.log.exception("There was an error reading the kubernetes API.") raise + + async def get_job_status(self, name: str, namespace: str) -> V1Job: + """ + Get job's status object. + + :param name: Name of the job. + :param namespace: Name of the job's namespace. + """ + async with self.get_conn() as connection: + v1_api = async_client.BatchV1Api(connection) + job: V1Job = await v1_api.read_namespaced_job_status( + name=name, + namespace=namespace, + ) + return job + + async def wait_until_job_complete(self, name: str, namespace: str, poll_interval: float = 10) -> V1Job: + """Block job of specified name and namespace until it is complete or failed. + + :param name: Name of Job to fetch. + :param namespace: Namespace of the Job. 
+ :param poll_interval: Interval in seconds between polling the job status + :return: Job object + """ + while True: + self.log.info("Requesting status for the job '%s' ", name) + job: V1Job = await self.get_job_status(name=name, namespace=namespace) + if self.is_job_complete(job=job): + return job + self.log.info("The job '%s' is incomplete. Sleeping for %i sec.", name, poll_interval) + await asyncio.sleep(poll_interval) diff --git a/airflow/providers/cncf/kubernetes/operators/job.py b/airflow/providers/cncf/kubernetes/operators/job.py index d487a789e25ed..41d260bc98483 100644 --- a/airflow/providers/cncf/kubernetes/operators/job.py +++ b/airflow/providers/cncf/kubernetes/operators/job.py @@ -28,6 +28,7 @@ from kubernetes.client.api_client import ApiClient from kubernetes.client.rest import ApiException +from airflow.configuration import conf from airflow.exceptions import AirflowException from airflow.models import BaseOperator from airflow.providers.cncf.kubernetes.hooks.kubernetes import KubernetesHook @@ -37,6 +38,7 @@ ) from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator, merge_objects +from airflow.providers.cncf.kubernetes.triggers.job import KubernetesJobTrigger from airflow.utils import yaml from airflow.utils.context import Context @@ -74,6 +76,8 @@ class KubernetesJobOperator(KubernetesPodOperator): Failed). Default is False. :param job_poll_interval: Interval in seconds between polling the job status. Default is 10. Used if the parameter `wait_until_job_complete` set True. + :param deferrable: Run operator in the deferrable mode. Note that the parameter + `wait_until_job_complete` must be set True. 
""" template_fields: Sequence[str] = tuple({"job_template_file"} | set(KubernetesPodOperator.template_fields)) @@ -93,6 +97,7 @@ def __init__( ttl_seconds_after_finished: int | None = None, wait_until_job_complete: bool = False, job_poll_interval: float = 10, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs, ) -> None: super().__init__(**kwargs) @@ -110,6 +115,7 @@ def __init__( self.ttl_seconds_after_finished = ttl_seconds_after_finished self.wait_until_job_complete = wait_until_job_complete self.job_poll_interval = job_poll_interval + self.deferrable = deferrable @cached_property def _incluster_namespace(self): @@ -139,6 +145,11 @@ def create_job(self, job_request_obj: k8s.V1Job) -> k8s.V1Job: return job_request_obj def execute(self, context: Context): + if self.deferrable and not self.wait_until_job_complete: + self.log.warning( + "Deferrable mode is available only with parameter `wait_until_job_complete=True`. " + "Please, set it up." + ) self.job_request_obj = self.build_job_request_obj(context) self.job = self.create_job( # must set `self.job` for `on_kill` job_request_obj=self.job_request_obj @@ -148,17 +159,43 @@ def execute(self, context: Context): ti.xcom_push(key="job_name", value=self.job.metadata.name) ti.xcom_push(key="job_namespace", value=self.job.metadata.namespace) + if self.wait_until_job_complete and self.deferrable: + self.execute_deferrable() + return + if self.wait_until_job_complete: self.job = self.hook.wait_until_job_complete( job_name=self.job.metadata.name, namespace=self.job.metadata.namespace, job_poll_interval=self.job_poll_interval, ) - ti.xcom_push( - key="job", value=self.hook.batch_v1_client.api_client.sanitize_for_serialization(self.job) + + ti.xcom_push(key="job", value=self.job.to_dict()) + if self.wait_until_job_complete: + if error_message := self.hook.is_job_failed(job=self.job): + raise AirflowException( + f"Kubernetes job '{self.job.metadata.name}' is failed with error 
'{error_message}'" + ) + + def execute_deferrable(self): + self.defer( + trigger=KubernetesJobTrigger( + job_name=self.job.metadata.name, # type: ignore[union-attr] + job_namespace=self.job.metadata.namespace, # type: ignore[union-attr] + kubernetes_conn_id=self.kubernetes_conn_id, + cluster_context=self.cluster_context, + config_file=self.config_file, + in_cluster=self.in_cluster, + poll_interval=self.job_poll_interval, + ), + method_name="execute_complete", ) - if self.hook.is_job_failed(job=self.job): - raise AirflowException(f"Kubernetes job '{self.job.metadata.name}' is failed") + + def execute_complete(self, context: Context, event: dict, **kwargs): + ti = context["ti"] + ti.xcom_push(key="job", value=event["job"]) + if event["status"] == "error": + raise AirflowException(event["message"]) @staticmethod def deserialize_job_template_file(path: str) -> k8s.V1Job: @@ -188,6 +225,7 @@ def on_kill(self) -> None: kwargs = { "name": job.metadata.name, "namespace": job.metadata.namespace, + "job": self.hook.batch_v1_client.api_client.sanitize_for_serialization(self.job), } if self.termination_grace_period is not None: kwargs.update(grace_period_seconds=self.termination_grace_period) diff --git a/airflow/providers/cncf/kubernetes/provider.yaml b/airflow/providers/cncf/kubernetes/provider.yaml index fa4644403d2ee..84fdd652db427 100644 --- a/airflow/providers/cncf/kubernetes/provider.yaml +++ b/airflow/providers/cncf/kubernetes/provider.yaml @@ -136,6 +136,7 @@ triggers: - integration-name: Kubernetes python-modules: - airflow.providers.cncf.kubernetes.triggers.pod + - airflow.providers.cncf.kubernetes.triggers.job connection-types: diff --git a/airflow/providers/cncf/kubernetes/triggers/job.py b/airflow/providers/cncf/kubernetes/triggers/job.py new file mode 100644 index 0000000000000..94f4667691153 --- /dev/null +++ b/airflow/providers/cncf/kubernetes/triggers/job.py @@ -0,0 +1,101 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor 
license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from functools import cached_property +from typing import TYPE_CHECKING, Any, AsyncIterator + +from airflow.providers.cncf.kubernetes.hooks.kubernetes import AsyncKubernetesHook +from airflow.triggers.base import BaseTrigger, TriggerEvent + +if TYPE_CHECKING: + from kubernetes.client import V1Job + + +class KubernetesJobTrigger(BaseTrigger): + """ + KubernetesJobTrigger run on the trigger worker to check the state of Job. + + :param job_name: The name of the job. + :param job_namespace: The namespace of the job. + :param kubernetes_conn_id: The :ref:`kubernetes connection id ` + for the Kubernetes cluster. + :param cluster_context: Context that points to kubernetes cluster. + :param config_file: Path to kubeconfig file. + :param poll_interval: Polling period in seconds to check for the status. + :param in_cluster: run kubernetes client with in_cluster configuration. 
+ """ + + def __init__( + self, + job_name: str, + job_namespace: str, + kubernetes_conn_id: str | None = None, + poll_interval: float = 2, + cluster_context: str | None = None, + config_file: str | None = None, + in_cluster: bool | None = None, + ): + super().__init__() + self.job_name = job_name + self.job_namespace = job_namespace + self.kubernetes_conn_id = kubernetes_conn_id + self.poll_interval = poll_interval + self.cluster_context = cluster_context + self.config_file = config_file + self.in_cluster = in_cluster + + def serialize(self) -> tuple[str, dict[str, Any]]: + """Serialize KubernetesJobTrigger arguments and classpath.""" + return ( + "airflow.providers.cncf.kubernetes.triggers.job.KubernetesJobTrigger", + { + "job_name": self.job_name, + "job_namespace": self.job_namespace, + "kubernetes_conn_id": self.kubernetes_conn_id, + "poll_interval": self.poll_interval, + "cluster_context": self.cluster_context, + "config_file": self.config_file, + "in_cluster": self.in_cluster, + }, + ) + + async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override] + """Get current job status and yield a TriggerEvent.""" + job: V1Job = await self.hook.wait_until_job_complete(name=self.job_name, namespace=self.job_namespace) + job_dict = job.to_dict() + error_message = self.hook.is_job_failed(job=job) + yield TriggerEvent( + { + "name": job.metadata.name, + "namespace": job.metadata.namespace, + "status": "error" if error_message else "success", + "message": f"Job failed with error: {error_message}" + if error_message + else "Job completed successfully", + "job": job_dict, + } + ) + + @cached_property + def hook(self) -> AsyncKubernetesHook: + return AsyncKubernetesHook( + conn_id=self.kubernetes_conn_id, + in_cluster=self.in_cluster, + config_file=self.config_file, + cluster_context=self.cluster_context, + ) diff --git a/docs/apache-airflow-providers-cncf-kubernetes/operators.rst b/docs/apache-airflow-providers-cncf-kubernetes/operators.rst index 
3fc7008549c72..80541f315c0d8 100644 --- a/docs/apache-airflow-providers-cncf-kubernetes/operators.rst +++ b/docs/apache-airflow-providers-cncf-kubernetes/operators.rst @@ -618,6 +618,15 @@ to ``~/.kube/config``. It also allows users to supply a template YAML file using :start-after: [START howto_operator_k8s_job] :end-before: [END howto_operator_k8s_job] +The :class:`~airflow.providers.cncf.kubernetes.operators.job.KubernetesJobOperator` also supports deferrable mode: + +.. exampleinclude:: /../../tests/system/providers/cncf/kubernetes/example_kubernetes_job.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_k8s_job_deferrable] + :end-before: [END howto_operator_k8s_job_deferrable] + + Difference between ``KubernetesPodOperator`` and ``KubernetesJobOperator`` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ The :class:`~airflow.providers.cncf.kubernetes.operators.job.KubernetesJobOperator` is operator for creating Job. diff --git a/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py b/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py index a1239364e4bb2..197aef4d822f5 100644 --- a/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py +++ b/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py @@ -469,14 +469,20 @@ def test_get_job_status(self, mock_client, mock_kube_config_merger, mock_kube_co ([mock.MagicMock(type="Complete", status=True)], False), ([mock.MagicMock(type="Complete", status=False)], False), ([mock.MagicMock(type="Failed", status=False)], False), - ([mock.MagicMock(type="Failed", status=True)], True), + ([mock.MagicMock(type="Failed", status=True, reason="test reason 1")], "test reason 1"), ( - [mock.MagicMock(type="Complete", status=False), mock.MagicMock(type="Failed", status=True)], - True, + [ + mock.MagicMock(type="Complete", status=False), + mock.MagicMock(type="Failed", status=True, reason="test reason 2"), + ], + "test reason 2", ), ( - [mock.MagicMock(type="Complete", 
status=True), mock.MagicMock(type="Failed", status=True)], - True, + [ + mock.MagicMock(type="Complete", status=True), + mock.MagicMock(type="Failed", status=True, reason="test reason 3"), + ], + "test reason 3", ), ], ) @@ -584,6 +590,8 @@ class TestAsyncKubernetesHook: INCLUSTER_CONFIG_LOADER = "kubernetes_asyncio.config.incluster_config.InClusterConfigLoader" KUBE_LOADER_CONFIG = "kubernetes_asyncio.config.kube_config.KubeConfigLoader" KUBE_API = "kubernetes_asyncio.client.api.core_v1_api.CoreV1Api.{}" + KUBE_BATCH_API = "kubernetes_asyncio.client.api.batch_v1_api.BatchV1Api.{}" + KUBE_ASYNC_HOOK = HOOK_MODULE + ".AsyncKubernetesHook.{}" @staticmethod def mock_await_result(return_value): @@ -765,3 +773,55 @@ async def test_read_logs(self, lib_method, kube_config_loader, caplog): timestamps=True, ) assert "Container logs from 2023-01-11 Some string logs..." in caplog.text + + @pytest.mark.asyncio + @mock.patch(KUBE_BATCH_API.format("read_namespaced_job_status")) + async def test_get_job_status(self, lib_method, kube_config_loader): + lib_method.return_value = self.mock_await_result(None) + + hook = AsyncKubernetesHook( + conn_id=None, + in_cluster=False, + config_file=None, + cluster_context=None, + ) + await hook.get_job_status( + name=JOB_NAME, + namespace=NAMESPACE, + ) + + lib_method.assert_called_once() + + @pytest.mark.asyncio + @mock.patch(HOOK_MODULE + ".asyncio.sleep") + @mock.patch(KUBE_ASYNC_HOOK.format("is_job_complete")) + @mock.patch(KUBE_ASYNC_HOOK.format("get_job_status")) + async def test_wait_until_job_complete( + self, mock_get_job_status, mock_is_job_complete, mock_sleep, kube_config_loader + ): + mock_job_0, mock_job_1 = mock.MagicMock(), mock.MagicMock() + mock_get_job_status.side_effect = mock.AsyncMock(side_effect=[mock_job_0, mock_job_1]) + mock_is_job_complete.side_effect = [False, True] + + hook = AsyncKubernetesHook( + conn_id=None, + in_cluster=False, + config_file=None, + cluster_context=None, + ) + + job_actual = await 
hook.wait_until_job_complete( + name=JOB_NAME, + namespace=NAMESPACE, + poll_interval=10, + ) + + mock_get_job_status.assert_has_awaits( + [ + mock.call(name=JOB_NAME, namespace=NAMESPACE), + mock.call(name=JOB_NAME, namespace=NAMESPACE), + ] + ) + mock_is_job_complete.assert_has_calls([mock.call(job=mock_job_0), mock.call(job=mock_job_1)]) + mock_sleep.assert_awaited_once_with(10) + assert job_actual == mock_job_1 diff --git a/tests/providers/cncf/kubernetes/operators/test_job.py b/tests/providers/cncf/kubernetes/operators/test_job.py index a5742abbbb796..803523429ce5c 100644 --- a/tests/providers/cncf/kubernetes/operators/test_job.py +++ b/tests/providers/cncf/kubernetes/operators/test_job.py @@ -39,6 +39,9 @@ JOB_OPERATORS_PATH = "airflow.providers.cncf.kubernetes.operators.job.{}" HOOK_CLASS = JOB_OPERATORS_PATH.format("KubernetesHook") POLL_INTERVAL = 100 +JOB_NAME = "test-job" +JOB_NAMESPACE = "test-namespace" +KUBERNETES_CONN_ID = "test-conn_id" def create_context(task, persist_to_db=False, map_index=None): @@ -480,6 +483,7 @@ def test_execute(self, mock_hook, mock_create_job, mock_build_job_request_obj): [ mock.call(key="job_name", value=mock_job_expected.metadata.name), mock.call(key="job_namespace", value=mock_job_expected.metadata.namespace), + mock.call(key="job", value=mock_job_expected.to_dict.return_value), ] ) @@ -489,19 +493,97 @@ def test_execute(self, mock_hook, mock_create_job, mock_build_job_request_obj): assert execute_result is None assert not mock_hook.wait_until_job_complete.called + @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.build_job_request_obj")) + @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.create_job")) + @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.execute_deferrable")) + @patch(HOOK_CLASS) + def test_execute_in_deferrable( + self, mock_hook, mock_execute_deferrable, mock_create_job, mock_build_job_request_obj + ): + mock_hook.return_value.is_job_failed.return_value = False + mock_job_request_obj = 
mock_build_job_request_obj.return_value + mock_job_expected = mock_create_job.return_value + mock_ti = mock.MagicMock() + context = dict(ti=mock_ti) + + op = KubernetesJobOperator( + task_id="test_task_id", + wait_until_job_complete=True, + deferrable=True, + ) + actual_result = op.execute(context=context) + + mock_build_job_request_obj.assert_called_once_with(context) + mock_create_job.assert_called_once_with(job_request_obj=mock_job_request_obj) + mock_ti.xcom_push.assert_has_calls( + [ + mock.call(key="job_name", value=mock_job_expected.metadata.name), + mock.call(key="job_namespace", value=mock_job_expected.metadata.namespace), + ] + ) + mock_execute_deferrable.assert_called_once() + + assert op.job_request_obj == mock_job_request_obj + assert op.job == mock_job_expected + assert actual_result is None + assert not mock_hook.wait_until_job_complete.called + @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.build_job_request_obj")) @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.create_job")) @patch(HOOK_CLASS) def test_execute_fail(self, mock_hook, mock_create_job, mock_build_job_request_obj): - mock_hook.return_value.is_job_failed.return_value = True + mock_hook.return_value.is_job_failed.return_value = "Error" op = KubernetesJobOperator( task_id="test_task_id", + wait_until_job_complete=True, ) with pytest.raises(AirflowException): op.execute(context=dict(ti=mock.MagicMock())) + @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.defer")) + @patch(JOB_OPERATORS_PATH.format("KubernetesJobTrigger")) + def test_execute_deferrable(self, mock_trigger, mock_execute_deferrable): + mock_cluster_context = mock.MagicMock() + mock_config_file = mock.MagicMock() + mock_in_cluster = mock.MagicMock() + + mock_job = mock.MagicMock() + mock_job.metadata.name = JOB_NAME + mock_job.metadata.namespace = JOB_NAMESPACE + + mock_trigger_instance = mock_trigger.return_value + + op = KubernetesJobOperator( + task_id="test_task_id", + 
kubernetes_conn_id=KUBERNETES_CONN_ID, + cluster_context=mock_cluster_context, + config_file=mock_config_file, + in_cluster=mock_in_cluster, + job_poll_interval=POLL_INTERVAL, + wait_until_job_complete=True, + deferrable=True, + ) + op.job = mock_job + + actual_result = op.execute_deferrable() + + mock_execute_deferrable.assert_called_once_with( + trigger=mock_trigger_instance, + method_name="execute_complete", + ) + mock_trigger.assert_called_once_with( + job_name=JOB_NAME, + job_namespace=JOB_NAMESPACE, + kubernetes_conn_id=KUBERNETES_CONN_ID, + cluster_context=mock_cluster_context, + config_file=mock_config_file, + in_cluster=mock_in_cluster, + poll_interval=POLL_INTERVAL, + ) + assert actual_result is None + @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.build_job_request_obj")) @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.create_job")) @patch(f"{HOOK_CLASS}.wait_until_job_complete") @@ -524,6 +606,82 @@ def test_wait_until_job_complete( job_poll_interval=POLL_INTERVAL, ) + def test_execute_complete(self): + mock_ti = mock.MagicMock() + context = {"ti": mock_ti} + mock_job = mock.MagicMock() + event = {"job": mock_job, "status": "success"} + + KubernetesJobOperator(task_id="test_task_id").execute_complete(context=context, event=event) + + mock_ti.xcom_push.assert_called_once_with(key="job", value=mock_job) + + def test_execute_complete_fail(self): + mock_ti = mock.MagicMock() + context = {"ti": mock_ti} + mock_job = mock.MagicMock() + event = {"job": mock_job, "status": "error", "message": "error message"} + + with pytest.raises(AirflowException): + KubernetesJobOperator(task_id="test_task_id").execute_complete(context=context, event=event) + + mock_ti.xcom_push.assert_called_once_with(key="job", value=mock_job) + + @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.client")) + @patch(HOOK_CLASS) + def test_on_kill(self, mock_hook, mock_client): + mock_job = mock.MagicMock() + mock_job.metadata.name = JOB_NAME + 
mock_job.metadata.namespace = JOB_NAMESPACE + mock_serialize = mock_hook.return_value.batch_v1_client.api_client.sanitize_for_serialization + mock_serialized_job = mock_serialize.return_value + + op = KubernetesJobOperator(task_id="test_task_id") + op.job = mock_job + op.on_kill() + + mock_client.delete_namespaced_job.assert_called_once_with( + name=JOB_NAME, + namespace=JOB_NAMESPACE, + job=mock_serialized_job, + ) + mock_serialize.assert_called_once_with(mock_job) + + @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.client")) + @patch(HOOK_CLASS) + def test_on_kill_termination_grace_period(self, mock_hook, mock_client): + mock_job = mock.MagicMock() + mock_job.metadata.name = JOB_NAME + mock_job.metadata.namespace = JOB_NAMESPACE + mock_serialize = mock_hook.return_value.batch_v1_client.api_client.sanitize_for_serialization + mock_serialized_job = mock_serialize.return_value + mock_termination_grace_period = mock.MagicMock() + + op = KubernetesJobOperator( + task_id="test_task_id", termination_grace_period=mock_termination_grace_period + ) + op.job = mock_job + op.on_kill() + + mock_client.delete_namespaced_job.assert_called_once_with( + name=JOB_NAME, + namespace=JOB_NAMESPACE, + job=mock_serialized_job, + grace_period_seconds=mock_termination_grace_period, + ) + mock_serialize.assert_called_once_with(mock_job) + + @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.client")) + @patch(HOOK_CLASS) + def test_on_kill_none_job(self, mock_hook, mock_client): + mock_serialize = mock_hook.return_value.batch_v1_client.api_client.sanitize_for_serialization + + op = KubernetesJobOperator(task_id="test_task_id") + op.on_kill() + + mock_client.delete_namespaced_job.assert_not_called() + mock_serialize.assert_not_called() + @pytest.mark.execution_timeout(300) class TestKubernetesDeleteJobOperator: diff --git a/tests/providers/cncf/kubernetes/triggers/test_job.py b/tests/providers/cncf/kubernetes/triggers/test_job.py new file mode 100644 index 
0000000000000..6124f5471c889 --- /dev/null +++ b/tests/providers/cncf/kubernetes/triggers/test_job.py @@ -0,0 +1,135 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from unittest import mock + +import pytest + +from airflow.providers.cncf.kubernetes.triggers.job import KubernetesJobTrigger +from airflow.triggers.base import TriggerEvent + +TRIGGER_PATH = "airflow.providers.cncf.kubernetes.triggers.job.{}" +TRIGGER_CLASS = TRIGGER_PATH.format("KubernetesJobTrigger") +HOOK_PATH = "airflow.providers.cncf.kubernetes.hooks.kubernetes.AsyncKubernetesHook" +JOB_NAME = "test-job-name" +NAMESPACE = "default" +CONN_ID = "test_kubernetes_conn_id" +POLL_INTERVAL = 2 +CLUSTER_CONTEXT = "test-context" +CONFIG_FILE = "/path/to/config/file" +IN_CLUSTER = False + + +@pytest.fixture +def trigger(): + return KubernetesJobTrigger( + job_name=JOB_NAME, + job_namespace=NAMESPACE, + kubernetes_conn_id=CONN_ID, + poll_interval=POLL_INTERVAL, + cluster_context=CLUSTER_CONTEXT, + config_file=CONFIG_FILE, + in_cluster=IN_CLUSTER, + ) + + +class TestKubernetesJobTrigger: + def test_serialize(self, trigger): + classpath, kwargs_dict = trigger.serialize() + + assert classpath == TRIGGER_CLASS + assert kwargs_dict == { 
+ "job_name": JOB_NAME, + "job_namespace": NAMESPACE, + "kubernetes_conn_id": CONN_ID, + "poll_interval": POLL_INTERVAL, + "cluster_context": CLUSTER_CONTEXT, + "config_file": CONFIG_FILE, + "in_cluster": IN_CLUSTER, + } + + @pytest.mark.asyncio + @mock.patch(f"{TRIGGER_CLASS}.hook") + async def test_run_success(self, mock_hook, trigger): + mock_job = mock.MagicMock() + mock_job.metadata.name = JOB_NAME + mock_job.metadata.namespace = NAMESPACE + mock_hook.wait_until_job_complete.side_effect = mock.AsyncMock(return_value=mock_job) + + mock_is_job_failed = mock_hook.is_job_failed + mock_is_job_failed.return_value = False + + mock_job_dict = mock_job.to_dict.return_value + + event_actual = await trigger.run().asend(None) + + mock_hook.wait_until_job_complete.assert_called_once_with(name=JOB_NAME, namespace=NAMESPACE) + mock_job.to_dict.assert_called_once() + mock_is_job_failed.assert_called_once_with(job=mock_job) + assert event_actual == TriggerEvent( + { + "name": JOB_NAME, + "namespace": NAMESPACE, + "status": "success", + "message": "Job completed successfully", + "job": mock_job_dict, + } + ) + + @pytest.mark.asyncio + @mock.patch(f"{TRIGGER_CLASS}.hook") + async def test_run_fail(self, mock_hook, trigger): + mock_job = mock.MagicMock() + mock_job.metadata.name = JOB_NAME + mock_job.metadata.namespace = NAMESPACE + mock_hook.wait_until_job_complete.side_effect = mock.AsyncMock(return_value=mock_job) + + mock_is_job_failed = mock_hook.is_job_failed + mock_is_job_failed.return_value = "Error" + + mock_job_dict = mock_job.to_dict.return_value + + event_actual = await trigger.run().asend(None) + + mock_hook.wait_until_job_complete.assert_called_once_with(name=JOB_NAME, namespace=NAMESPACE) + mock_job.to_dict.assert_called_once() + mock_is_job_failed.assert_called_once_with(job=mock_job) + assert event_actual == TriggerEvent( + { + "name": JOB_NAME, + "namespace": NAMESPACE, + "status": "error", + "message": "Job failed with error: Error", + "job": mock_job_dict, + } 
+ ) + + @mock.patch(TRIGGER_PATH.format("AsyncKubernetesHook")) + def test_hook(self, mock_hook, trigger): + hook_expected = mock_hook.return_value + + hook_actual = trigger.hook + + mock_hook.assert_called_once_with( + conn_id=CONN_ID, + in_cluster=IN_CLUSTER, + config_file=CONFIG_FILE, + cluster_context=CLUSTER_CONTEXT, + ) + assert hook_actual == hook_expected diff --git a/tests/system/providers/cncf/kubernetes/example_kubernetes_job.py b/tests/system/providers/cncf/kubernetes/example_kubernetes_job.py index a5c97d9c95fe4..0f17f57a15414 100644 --- a/tests/system/providers/cncf/kubernetes/example_kubernetes_job.py +++ b/tests/system/providers/cncf/kubernetes/example_kubernetes_job.py @@ -57,11 +57,23 @@ update_job = KubernetesPatchJobOperator( task_id="update-job-task", namespace="default", - name="test-pi", + name=JOB_NAME, body={"spec": {"suspend": False}}, ) # [END howto_operator_update_job] + # [START howto_operator_k8s_job_deferrable] + k8s_job_def = KubernetesJobOperator( + task_id="job-task-def", + namespace="default", + image="perl:5.34.0", + cmds=["perl", "-Mbignum=bpi", "-wle", "print bpi(2000)"], + name=JOB_NAME + "-def", + wait_until_job_complete=True, + deferrable=True, + ) + # [END howto_operator_k8s_job_deferrable] + # [START howto_operator_delete_k8s_job] delete_job_task = KubernetesDeleteJobOperator( task_id="delete_job_task", @@ -70,7 +82,14 @@ ) # [END howto_operator_delete_k8s_job] + delete_job_task_def = KubernetesDeleteJobOperator( + task_id="delete_job_task_def", + name=JOB_NAME + "-def", + namespace=JOB_NAMESPACE, + ) + k8s_job >> update_job >> delete_job_task + k8s_job_def >> delete_job_task_def from tests.system.utils.watcher import watcher