Skip to content

Commit

Permalink
Implement deferrable mode for KubernetesJobOperator
Browse files Browse the repository at this point in the history
  • Loading branch information
moiseenkov committed Mar 22, 2024
1 parent 095c5fe commit 64d7e2d
Show file tree
Hide file tree
Showing 9 changed files with 569 additions and 14 deletions.
40 changes: 37 additions & 3 deletions airflow/providers/cncf/kubernetes/hooks/kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations

import asyncio
import contextlib
import json
import tempfile
Expand Down Expand Up @@ -582,13 +583,15 @@ def is_job_complete(self, job: V1Job) -> bool:
return False

@staticmethod
def is_job_failed(job: V1Job) -> bool:
def is_job_failed(job: V1Job) -> str | bool:
"""Check whether the given job is failed.
:return: Boolean indicating that the given job is failed.
:return: Error message if the job is failed, and False otherwise.
"""
conditions = job.status.conditions or []
return bool(next((c for c in conditions if c.type == "Failed" and c.status), None))
if fail_condition := next((c for c in conditions if c.type == "Failed" and c.status), None):
return fail_condition.reason
return False

def patch_namespaced_job(self, job_name: str, namespace: str, body: object) -> V1Job:
"""
Expand Down Expand Up @@ -769,3 +772,34 @@ async def read_logs(self, name: str, namespace: str):
except HTTPError:
self.log.exception("There was an error reading the kubernetes API.")
raise

async def get_job_status(self, name: str, namespace: str) -> V1Job:
"""
Get job's status object.
:param name: Name of the pod.
:param namespace: Name of the pod's namespace.
"""
async with self.get_conn() as connection:
v1_api = async_client.BatchV1Api(connection)
job: V1Job = await v1_api.read_namespaced_job_status(
name=name,
namespace=namespace,
)
return job

async def wait_until_job_complete(self, name: str, namespace: str, poll_interval: float = 10) -> V1Job:
"""Block job of specified name and namespace until it is complete or failed.
:param name: Name of Job to fetch.
:param namespace: Namespace of the Job.
:param poll_interval: Interval in seconds between polling the job status
:return: Job object
"""
while True:
self.log.info("Requesting status for the job '%s' ", name)
job: V1Job = await self.get_job_status(name=name, namespace=namespace)
if self.is_job_complete(job=job):
return job
self.log.info("The job '%s' is incomplete. Sleeping for %i sec.", name, poll_interval)
await asyncio.sleep(poll_interval)
46 changes: 42 additions & 4 deletions airflow/providers/cncf/kubernetes/operators/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from kubernetes.client.api_client import ApiClient
from kubernetes.client.rest import ApiException

from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.cncf.kubernetes.hooks.kubernetes import KubernetesHook
Expand All @@ -37,6 +38,7 @@
)
from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator
from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator, merge_objects
from airflow.providers.cncf.kubernetes.triggers.job import KubernetesJobTrigger
from airflow.utils import yaml
from airflow.utils.context import Context

Expand Down Expand Up @@ -74,6 +76,8 @@ class KubernetesJobOperator(KubernetesPodOperator):
Failed). Default is False.
:param job_poll_interval: Interval in seconds between polling the job status. Default is 10.
Used if the parameter `wait_until_job_complete` set True.
:param deferrable: Run operator in the deferrable mode. Note that the parameter
`wait_until_job_complete` must be set True.
"""

template_fields: Sequence[str] = tuple({"job_template_file"} | set(KubernetesPodOperator.template_fields))
Expand All @@ -93,6 +97,7 @@ def __init__(
ttl_seconds_after_finished: int | None = None,
wait_until_job_complete: bool = False,
job_poll_interval: float = 10,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand All @@ -110,6 +115,7 @@ def __init__(
self.ttl_seconds_after_finished = ttl_seconds_after_finished
self.wait_until_job_complete = wait_until_job_complete
self.job_poll_interval = job_poll_interval
self.deferrable = deferrable

@cached_property
def _incluster_namespace(self):
Expand Down Expand Up @@ -139,6 +145,11 @@ def create_job(self, job_request_obj: k8s.V1Job) -> k8s.V1Job:
return job_request_obj

def execute(self, context: Context):
if self.deferrable and not self.wait_until_job_complete:
self.log.warning(
"Deferrable mode is available only with parameter `wait_until_job_complete=True`. "
"Please, set it up."
)
self.job_request_obj = self.build_job_request_obj(context)
self.job = self.create_job( # must set `self.job` for `on_kill`
job_request_obj=self.job_request_obj
Expand All @@ -148,17 +159,43 @@ def execute(self, context: Context):
ti.xcom_push(key="job_name", value=self.job.metadata.name)
ti.xcom_push(key="job_namespace", value=self.job.metadata.namespace)

if self.wait_until_job_complete and self.deferrable:
self.execute_deferrable()
return

if self.wait_until_job_complete:
self.job = self.hook.wait_until_job_complete(
job_name=self.job.metadata.name,
namespace=self.job.metadata.namespace,
job_poll_interval=self.job_poll_interval,
)
ti.xcom_push(
key="job", value=self.hook.batch_v1_client.api_client.sanitize_for_serialization(self.job)

ti.xcom_push(key="job", value=self.job.to_dict())
if self.wait_until_job_complete:
if error_message := self.hook.is_job_failed(job=self.job):
raise AirflowException(
f"Kubernetes job '{self.job.metadata.name}' is failed with error '{error_message}'"
)

def execute_deferrable(self):
self.defer(
trigger=KubernetesJobTrigger(
job_name=self.job.metadata.name, # type: ignore[union-attr]
job_namespace=self.job.metadata.namespace, # type: ignore[union-attr]
kubernetes_conn_id=self.kubernetes_conn_id,
cluster_context=self.cluster_context,
config_file=self.config_file,
in_cluster=self.in_cluster,
poll_interval=self.job_poll_interval,
),
method_name="execute_complete",
)
if self.hook.is_job_failed(job=self.job):
raise AirflowException(f"Kubernetes job '{self.job.metadata.name}' is failed")

def execute_complete(self, context: Context, event: dict, **kwargs):
ti = context["ti"]
ti.xcom_push(key="job", value=event["job"])
if event["status"] == "error":
raise AirflowException(event["message"])

@staticmethod
def deserialize_job_template_file(path: str) -> k8s.V1Job:
Expand Down Expand Up @@ -188,6 +225,7 @@ def on_kill(self) -> None:
kwargs = {
"name": job.metadata.name,
"namespace": job.metadata.namespace,
"job": self.hook.batch_v1_client.api_client.sanitize_for_serialization(self.job),
}
if self.termination_grace_period is not None:
kwargs.update(grace_period_seconds=self.termination_grace_period)
Expand Down
1 change: 1 addition & 0 deletions airflow/providers/cncf/kubernetes/provider.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ triggers:
- integration-name: Kubernetes
python-modules:
- airflow.providers.cncf.kubernetes.triggers.pod
- airflow.providers.cncf.kubernetes.triggers.job


connection-types:
Expand Down
101 changes: 101 additions & 0 deletions airflow/providers/cncf/kubernetes/triggers/job.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from functools import cached_property
from typing import TYPE_CHECKING, Any, AsyncIterator

from airflow.providers.cncf.kubernetes.hooks.kubernetes import AsyncKubernetesHook
from airflow.triggers.base import BaseTrigger, TriggerEvent

if TYPE_CHECKING:
from kubernetes.client import V1Job


class KubernetesJobTrigger(BaseTrigger):
"""
KubernetesJobTrigger run on the trigger worker to check the state of Job.
:param job_name: The name of the job.
:param job_namespace: The namespace of the job.
:param kubernetes_conn_id: The :ref:`kubernetes connection id <howto/connection:kubernetes>`
for the Kubernetes cluster.
:param cluster_context: Context that points to kubernetes cluster.
:param config_file: Path to kubeconfig file.
:param poll_interval: Polling period in seconds to check for the status.
:param in_cluster: run kubernetes client with in_cluster configuration.
"""

def __init__(
self,
job_name: str,
job_namespace: str,
kubernetes_conn_id: str | None = None,
poll_interval: float = 2,
cluster_context: str | None = None,
config_file: str | None = None,
in_cluster: bool | None = None,
):
super().__init__()
self.job_name = job_name
self.job_namespace = job_namespace
self.kubernetes_conn_id = kubernetes_conn_id
self.poll_interval = poll_interval
self.cluster_context = cluster_context
self.config_file = config_file
self.in_cluster = in_cluster

def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serialize KubernetesCreateJobTrigger arguments and classpath."""
return (
"airflow.providers.cncf.kubernetes.triggers.job.KubernetesJobTrigger",
{
"job_name": self.job_name,
"job_namespace": self.job_namespace,
"kubernetes_conn_id": self.kubernetes_conn_id,
"poll_interval": self.poll_interval,
"cluster_context": self.cluster_context,
"config_file": self.config_file,
"in_cluster": self.in_cluster,
},
)

async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
"""Get current job status and yield a TriggerEvent."""
job: V1Job = await self.hook.wait_until_job_complete(name=self.job_name, namespace=self.job_namespace)
job_dict = job.to_dict()
error_message = self.hook.is_job_failed(job=job)
yield TriggerEvent(
{
"name": job.metadata.name,
"namespace": job.metadata.namespace,
"status": "error" if error_message else "success",
"message": f"Job failed with error: {error_message}"
if error_message
else "Job completed successfully",
"job": job_dict,
}
)

@cached_property
def hook(self) -> AsyncKubernetesHook:
return AsyncKubernetesHook(
conn_id=self.kubernetes_conn_id,
in_cluster=self.in_cluster,
config_file=self.config_file,
cluster_context=self.cluster_context,
)
9 changes: 9 additions & 0 deletions docs/apache-airflow-providers-cncf-kubernetes/operators.rst
Original file line number Diff line number Diff line change
Expand Up @@ -618,6 +618,15 @@ to ``~/.kube/config``. It also allows users to supply a template YAML file using
:start-after: [START howto_operator_k8s_job]
:end-before: [END howto_operator_k8s_job]

The :class:`~airflow.providers.cncf.kubernetes.operators.job.KubernetesJobOperator` also supports deferrable mode:

.. exampleinclude:: /../../tests/system/providers/cncf/kubernetes/example_kubernetes_job.py
:language: python
:dedent: 4
:start-after: [START howto_operator_k8s_job_deferrable]
:end-before: [END howto_operator_k8s_job_deferrable]


Difference between ``KubernetesPodOperator`` and ``KubernetesJobOperator``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
The :class:`~airflow.providers.cncf.kubernetes.operators.job.KubernetesJobOperator` is operator for creating Job.
Expand Down
70 changes: 65 additions & 5 deletions tests/providers/cncf/kubernetes/hooks/test_kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,14 +469,20 @@ def test_get_job_status(self, mock_client, mock_kube_config_merger, mock_kube_co
([mock.MagicMock(type="Complete", status=True)], False),
([mock.MagicMock(type="Complete", status=False)], False),
([mock.MagicMock(type="Failed", status=False)], False),
([mock.MagicMock(type="Failed", status=True)], True),
([mock.MagicMock(type="Failed", status=True, reason="test reason 1")], "test reason 1"),
(
[mock.MagicMock(type="Complete", status=False), mock.MagicMock(type="Failed", status=True)],
True,
[
mock.MagicMock(type="Complete", status=False),
mock.MagicMock(type="Failed", status=True, reason="test reason 2"),
],
"test reason 2",
),
(
[mock.MagicMock(type="Complete", status=True), mock.MagicMock(type="Failed", status=True)],
True,
[
mock.MagicMock(type="Complete", status=True),
mock.MagicMock(type="Failed", status=True, reason="test reason 3"),
],
"test reason 3",
),
],
)
Expand Down Expand Up @@ -584,6 +590,8 @@ class TestAsyncKubernetesHook:
INCLUSTER_CONFIG_LOADER = "kubernetes_asyncio.config.incluster_config.InClusterConfigLoader"
KUBE_LOADER_CONFIG = "kubernetes_asyncio.config.kube_config.KubeConfigLoader"
KUBE_API = "kubernetes_asyncio.client.api.core_v1_api.CoreV1Api.{}"
KUBE_BATCH_API = "kubernetes_asyncio.client.api.batch_v1_api.BatchV1Api.{}"
KUBE_ASYNC_HOOK = HOOK_MODULE + ".AsyncKubernetesHook.{}"

@staticmethod
def mock_await_result(return_value):
Expand Down Expand Up @@ -765,3 +773,55 @@ async def test_read_logs(self, lib_method, kube_config_loader, caplog):
timestamps=True,
)
assert "Container logs from 2023-01-11 Some string logs..." in caplog.text

@pytest.mark.asyncio
@mock.patch(KUBE_BATCH_API.format("read_namespaced_job_status"))
async def test_get_job_status(self, lib_method, kube_config_loader):
lib_method.return_value = self.mock_await_result(None)

hook = AsyncKubernetesHook(
conn_id=None,
in_cluster=False,
config_file=None,
cluster_context=None,
)
await hook.get_job_status(
name=JOB_NAME,
namespace=NAMESPACE,
)

lib_method.assert_called_once()

@pytest.mark.asyncio
@mock.patch(HOOK_MODULE + ".asyncio.sleep")
@mock.patch(KUBE_ASYNC_HOOK.format("is_job_complete"))
@mock.patch(KUBE_ASYNC_HOOK.format("get_job_status"))
async def test_wait_until_job_complete(
self, mock_get_job_status, mock_is_job_complete, mock_sleep, kube_config_loader
):
mock_job_0, mock_job_1 = mock.MagicMock(), mock.MagicMock()
mock_get_job_status.side_effect = mock.AsyncMock(side_effect=[mock_job_0, mock_job_1])
mock_is_job_complete.side_effect = [False, True]

hook = AsyncKubernetesHook(
conn_id=None,
in_cluster=False,
config_file=None,
cluster_context=None,
)

job_actual = await hook.wait_until_job_complete(
name=JOB_NAME,
namespace=NAMESPACE,
poll_interval=10,
)

mock_get_job_status.assert_has_awaits(
[
mock.call(name=JOB_NAME, namespace=NAMESPACE),
mock.call(name=JOB_NAME, namespace=NAMESPACE),
]
)
mock_is_job_complete.assert_has_calls([mock.call(job=mock_job_0), mock.call(job=mock_job_1)])
mock_sleep.assert_awaited_once_with(10)
assert job_actual == mock_job_1
Loading

0 comments on commit 64d7e2d

Please sign in to comment.