Skip to content

Commit

Permalink
Implement KubernetesInstallKueueOperator +
Browse files Browse the repository at this point in the history
KubernetesStartKueueJobOperator and refactor GKE operatores
  • Loading branch information
moiseenkov committed Dec 2, 2024
1 parent 35000c9 commit 1967094
Show file tree
Hide file tree
Showing 11 changed files with 1,716 additions and 2,440 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,11 @@
from typing import TYPE_CHECKING, Any

import aiofiles
import requests
import tenacity
from asgiref.sync import sync_to_async
from kubernetes import client, config, watch
from kubernetes import client, config, utils, watch
from kubernetes.client.models import V1Deployment
from kubernetes.config import ConfigException
from kubernetes_asyncio import client as async_client, config as async_config
from urllib3.exceptions import HTTPError
Expand All @@ -47,7 +49,7 @@

if TYPE_CHECKING:
from kubernetes.client import V1JobList
from kubernetes.client.models import V1Deployment, V1Job, V1Pod
from kubernetes.client.models import V1Job, V1Pod

LOADING_KUBE_CONFIG_FILE_RESOURCE = "Loading Kubernetes configuration file kube_config from {}..."

Expand Down Expand Up @@ -489,12 +491,9 @@ def get_deployment_status(
:param name: Name of Deployment to retrieve
:param namespace: Deployment namespace
"""
try:
return self.apps_v1_client.read_namespaced_deployment_status(
name=name, namespace=namespace, pretty=True, **kwargs
)
except Exception as exc:
raise exc
return self.apps_v1_client.read_namespaced_deployment_status(
name=name, namespace=namespace, pretty=True, **kwargs
)

@tenacity.retry(
stop=tenacity.stop_after_attempt(3),
Expand Down Expand Up @@ -644,6 +643,71 @@ def patch_namespaced_job(self, job_name: str, namespace: str, body: object) -> V
body=body,
)

def apply_from_yaml_file(
self,
api_client: Any = None,
yaml_file: str | None = None,
yaml_objects: list[dict] | None = None,
verbose: bool = False,
namespace: str = "default",
):
"""
Perform an action from a yaml file.
:param api_client: A Kubernetes client application.
:param yaml_file: Contains the path to yaml file.
:param yaml_objects: List of YAML objects; used instead of reading the yaml_file.
:param verbose: If True, print confirmation from create action. Default is False.
:param namespace: Contains the namespace to create all resources inside. The namespace must
preexist otherwise the resource creation will fail.
"""
utils.create_from_yaml(
k8s_client=api_client or self.api_client,
yaml_objects=yaml_objects,
yaml_file=yaml_file,
verbose=verbose,
namespace=namespace or self.get_namespace(),
)

def check_kueue_deployment_running(
self, name: str, namespace: str, timeout: float = 300.0, polling_period_seconds: float = 2.0
) -> None:
_timeout = timeout
while _timeout > 0:
try:
deployment = self.get_deployment_status(name=name, namespace=namespace)
except Exception as e:
self.log.exception("Exception occurred while checking for Deployment status.")
raise e

deployment_status = V1Deployment.to_dict(deployment)["status"]
replicas = deployment_status["replicas"]
ready_replicas = deployment_status["ready_replicas"]
unavailable_replicas = deployment_status["unavailable_replicas"]
if (
replicas is not None
and ready_replicas is not None
and unavailable_replicas is None
and replicas == ready_replicas
):
return
else:
self.log.info("Waiting until Deployment will be ready...")
sleep(polling_period_seconds)

_timeout -= polling_period_seconds

raise AirflowException("Deployment timed out")

@staticmethod
def get_yaml_content_from_file(kueue_yaml_url) -> list[dict]:
"""Download content of YAML file and separate it into several dictionaries."""
response = requests.get(kueue_yaml_url, allow_redirects=True)
if response.status_code != 200:
raise AirflowException("Was not able to read the yaml file from given URL")

return list(yaml.safe_load_all(response.text))


def _get_bool(val) -> bool | None:
"""Convert val to bool if can be done with certainty; if we cannot infer intention we return None."""
Expand Down
105 changes: 105 additions & 0 deletions providers/src/airflow/providers/cncf/kubernetes/operators/kueue.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Manage a Kubernetes Kueue."""

from __future__ import annotations

import json
import warnings
from collections.abc import Sequence
from functools import cached_property

from kubernetes.utils import FailToCreateError

from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.cncf.kubernetes.hooks.kubernetes import KubernetesHook
from airflow.providers.cncf.kubernetes.operators.job import KubernetesJobOperator


class KubernetesInstallKueueOperator(BaseOperator):
"""
Installs a Kubernetes Kueue.
:param kueue_version: The Kubernetes Kueue version to install.
:param kubernetes_conn_id: The :ref:`kubernetes connection id <howto/connection:kubernetes>`
for the Kubernetes cluster.
"""

template_fields: Sequence[str] = (
"kueue_version",
"kubernetes_conn_id",
)

def __init__(self, kueue_version: str, kubernetes_conn_id: str = "kubernetes_default", *args, **kwargs):
super().__init__(*args, **kwargs)
self.kubernetes_conn_id = kubernetes_conn_id
self.kueue_version = kueue_version
self._kueue_yaml_url = (
f"https://github.com/kubernetes-sigs/kueue/releases/download/{self.kueue_version}/manifests.yaml"
)

@cached_property
def hook(self) -> KubernetesHook:
return KubernetesHook(conn_id=self.kubernetes_conn_id)

def execute(self, context):
yaml_objects = self.hook.get_yaml_content_from_file(kueue_yaml_url=self._kueue_yaml_url)
try:
self.hook.apply_from_yaml_file(yaml_objects=yaml_objects)
except FailToCreateError as ex:
error_bodies = [json.loads(e.body) for e in ex.api_exceptions]
if next((e for e in error_bodies if e.get("reason") == "AlreadyExists"), None):
self.log.info("Kueue is already enabled for the cluster")

if errors := [e for e in error_bodies if e.get("reason") != "AlreadyExists"]:
error_message = "\n".join(e.get("body") for e in errors)
raise AirflowException(error_message)
return

self.hook.check_kueue_deployment_running(name="kueue-controller-manager", namespace="kueue-system")
self.log.info("Kueue installed successfully!")


class KubernetesStartKueueJobOperator(KubernetesJobOperator):
"""
Executes a Kubernetes Job in Kueue.
:param queue_name: The name of the Queue in the cluster
"""

template_fields = tuple({"queue_name"} | set(KubernetesJobOperator.template_fields))

def __init__(self, queue_name: str, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.queue_name = queue_name

if self.suspend is False:
raise AirflowException(
"The `suspend` parameter can't be False. If you want to use Kueue for running Job"
" in a Kubernetes cluster, set the `suspend` parameter to True.",
)
elif self.suspend is None:
warnings.warn(
f"You have not set parameter `suspend` in class {self.__class__.__name__}. "
"For running a Job in Kueue the `suspend` parameter should set to True.",
UserWarning,
stacklevel=2,
)
self.suspend = True
self.labels.update({"kueue.x-k8s.io/queue-name": self.queue_name})
self.annotations.update({"kueue.x-k8s.io/queue-name": self.queue_name})
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,9 @@ def __init__(
self.remote_pod: k8s.V1Pod | None = None
self.log_pod_spec_on_failure = log_pod_spec_on_failure
self.on_finish_action = OnFinishAction(on_finish_action)
self.is_delete_operator_pod = self.on_finish_action == OnFinishAction.DELETE_POD
# The `is_delete_operator_pod` parameter should have been removed in provider version 10.0.0.
# TODO: remove it from here and from the operator's parameters list when the next major version bumped
self._is_delete_operator_pod = self.on_finish_action == OnFinishAction.DELETE_POD
self.termination_message_policy = termination_message_policy
self.active_deadline_seconds = active_deadline_seconds
self.logging_interval = logging_interval
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import json
import time
from collections.abc import Sequence
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

from google.api_core.exceptions import NotFound
from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
Expand All @@ -33,8 +33,7 @@
from google.cloud import exceptions # type: ignore[attr-defined]
from google.cloud.container_v1 import ClusterManagerAsyncClient, ClusterManagerClient
from google.cloud.container_v1.types import Cluster, Operation
from kubernetes import client, utils
from kubernetes.client.models import V1Deployment
from kubernetes import client
from kubernetes_asyncio import client as async_client
from kubernetes_asyncio.config.kube_config import FileOrData

Expand Down Expand Up @@ -434,38 +433,9 @@ def get_conn(self) -> client.ApiClient:
enable_tcp_keepalive=self.enable_tcp_keepalive,
).get_conn()

def check_kueue_deployment_running(self, name, namespace):
timeout = 300
polling_period_seconds = 2

while timeout is None or timeout > 0:
try:
deployment = self.get_deployment_status(name=name, namespace=namespace)
deployment_status = V1Deployment.to_dict(deployment)["status"]
replicas = deployment_status["replicas"]
ready_replicas = deployment_status["ready_replicas"]
unavailable_replicas = deployment_status["unavailable_replicas"]
if (
replicas is not None
and ready_replicas is not None
and unavailable_replicas is None
and replicas == ready_replicas
):
return
else:
self.log.info("Waiting until Deployment will be ready...")
time.sleep(polling_period_seconds)
except Exception as e:
self.log.exception("Exception occurred while checking for Deployment status.")
raise e

if timeout is not None:
timeout -= polling_period_seconds

raise AirflowException("Deployment timed out")

def apply_from_yaml_file(
self,
api_client: Any = None,
yaml_file: str | None = None,
yaml_objects: list[dict] | None = None,
verbose: bool = False,
Expand All @@ -474,18 +444,17 @@ def apply_from_yaml_file(
"""
Perform an action from a yaml file.
:param api_client: A Kubernetes client application.
:param yaml_file: Contains the path to yaml file.
:param yaml_objects: List of YAML objects; used instead of reading the yaml_file.
:param verbose: If True, print confirmation from create action. Default is False.
:param namespace: Contains the namespace to create all resources inside. The namespace must
preexist otherwise the resource creation will fail.
"""
k8s_client = self.get_conn()

utils.create_from_yaml(
k8s_client=k8s_client,
yaml_objects=yaml_objects,
super().apply_from_yaml_file(
api_client=api_client or self.get_conn(),
yaml_file=yaml_file,
yaml_objects=yaml_objects,
verbose=verbose,
namespace=namespace,
)
Expand Down
Loading

0 comments on commit 1967094

Please sign in to comment.