Skip to content

Commit

Permalink
Add GKEStartKueueInsideClusterOperator
Browse files Browse the repository at this point in the history
  • Loading branch information
VladaZakharova committed Jan 29, 2024
1 parent 6f41010 commit c23cbc5
Show file tree
Hide file tree
Showing 8 changed files with 823 additions and 22 deletions.
24 changes: 23 additions & 1 deletion airflow/providers/cncf/kubernetes/hooks/kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from airflow.utils import yaml

if TYPE_CHECKING:
from kubernetes.client.models import V1Pod
from kubernetes.client.models import V1Deployment, V1Pod

LOADING_KUBE_CONFIG_FILE_RESOURCE = "Loading Kubernetes configuration file kube_config from {}..."

Expand Down Expand Up @@ -282,6 +282,10 @@ def api_client(self) -> client.ApiClient:
def core_v1_client(self) -> client.CoreV1Api:
    """Return a ``CoreV1Api`` built on top of this hook's ``api_client``."""
    shared_api_client = self.api_client
    return client.CoreV1Api(api_client=shared_api_client)

@cached_property
def apps_v1_client(self) -> client.AppsV1Api:
    """Return an ``AppsV1Api`` built on top of this hook's ``api_client``."""
    shared_api_client = self.api_client
    return client.AppsV1Api(api_client=shared_api_client)

@cached_property
def custom_object_client(self) -> client.CustomObjectsApi:
    """Return a ``CustomObjectsApi`` built on top of this hook's ``api_client``."""
    shared_api_client = self.api_client
    return client.CustomObjectsApi(api_client=shared_api_client)
Expand Down Expand Up @@ -450,6 +454,24 @@ def get_namespaced_pod_list(
**kwargs,
)

def get_deployment_status(
    self,
    name: str,
    namespace: str = "default",
    **kwargs,
) -> V1Deployment:
    """Get status of existing Deployment.

    :param name: Name of Deployment to retrieve
    :param namespace: Deployment namespace
    :param kwargs: Extra keyword arguments forwarded unchanged to
        ``AppsV1Api.read_namespaced_deployment_status``.
    :return: Whatever ``AppsV1Api.read_namespaced_deployment_status`` returns for the Deployment.
    """
    # Any error raised by the API client propagates to the caller unchanged; the
    # former ``try/except Exception as exc: raise exc`` wrapper added nothing but
    # an extra traceback frame.
    return self.apps_v1_client.read_namespaced_deployment_status(
        name=name, namespace=namespace, pretty=True, **kwargs
    )


def _get_bool(val) -> bool | None:
"""Convert val to bool if can be done with certainty; if we cannot infer intention we return None."""
Expand Down
158 changes: 157 additions & 1 deletion airflow/providers/google/cloud/hooks/kubernetes_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,15 @@
from google.cloud import container_v1, exceptions # type: ignore[attr-defined]
from google.cloud.container_v1 import ClusterManagerAsyncClient, ClusterManagerClient
from google.cloud.container_v1.types import Cluster, Operation
from kubernetes import client
from kubernetes import client, utils
from kubernetes_asyncio import client as async_client
from kubernetes_asyncio.client.models import V1Deployment
from kubernetes_asyncio.config.kube_config import FileOrData
from urllib3.exceptions import HTTPError

from airflow import version
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.providers.cncf.kubernetes.hooks.kubernetes import KubernetesHook
from airflow.providers.cncf.kubernetes.kube_client import _enable_tcp_keepalive
from airflow.providers.cncf.kubernetes.utils.pod_manager import PodOperatorHookProtocol
from airflow.providers.google.common.consts import CLIENT_INFO
Expand Down Expand Up @@ -298,6 +300,130 @@ def get_cluster(
timeout=timeout,
)

def check_cluster_autoscaling_ability(self, cluster: Cluster | dict):
    """
    Helper method to check if the specified Cluster has ability to autoscale.

    Cluster should be Autopilot, with Node Auto-provisioning or regular auto-scaled node pools.
    Returns True if the Cluster supports autoscaling, otherwise returns False.

    Missing or empty sections of the cluster description (``autopilot``, ``autoscaling``,
    ``node_pools``, or a node pool without an ``autoscaling`` entry) are treated as
    "not enabled" instead of raising ``KeyError``.

    :param cluster: The Cluster object, either a ``Cluster`` proto or its dict representation.
    :return: ``True`` if any of the three autoscaling mechanisms is enabled, ``False`` otherwise.
    :raises AirflowException: If ``cluster`` is neither a ``Cluster`` proto nor a dict.
    """
    if isinstance(cluster, dict):
        cluster_dict_representation = cluster
    elif isinstance(cluster, Cluster):
        cluster_dict_representation = Cluster.to_dict(cluster)
    else:
        raise AirflowException("cluster is not instance of Cluster proto or python dict")

    def _section_flag(section: str, flag: str) -> bool:
        # ``or {}`` also covers a section that is present but set to None.
        return bool((cluster_dict_representation.get(section) or {}).get(flag))

    if _section_flag("autopilot", "enabled") or _section_flag(
        "autoscaling", "enable_node_autoprovisioning"
    ):
        return True

    # Inspect every node pool: one pool without autoscaling settings must not hide
    # a later pool that has autoscaling enabled.
    return any(
        (node_pool.get("autoscaling") or {}).get("enabled") is True
        for node_pool in cluster_dict_representation.get("node_pools") or []
    )


class GKEDeploymentHook(GoogleBaseHook, KubernetesHook):
    """Google Kubernetes Engine Deployment APIs.

    :param cluster_url: Used as the ``host`` of the Kubernetes client configuration.
    :param ssl_ca_cert: Certificate authority data for the cluster; passed to ``FileOrData``
        as ``certificate-authority-data``.
    """

    def __init__(
        self,
        cluster_url: str,
        ssl_ca_cert: str,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self._cluster_url = cluster_url
        self._ssl_ca_cert = ssl_ca_cert

    @cached_property
    def api_client(self) -> client.ApiClient:
        """Return the ``ApiClient`` produced by :meth:`get_conn`."""
        return self.get_conn()

    @cached_property
    def core_v1_client(self) -> client.CoreV1Api:
        """Return a ``CoreV1Api`` built on top of ``api_client``."""
        return client.CoreV1Api(self.api_client)

    @cached_property
    def batch_v1_client(self) -> client.BatchV1Api:
        """Return a ``BatchV1Api`` built on top of ``api_client``."""
        return client.BatchV1Api(self.api_client)

    @cached_property
    def apps_v1_client(self) -> client.AppsV1Api:
        """Return an ``AppsV1Api`` built on top of ``api_client``."""
        return client.AppsV1Api(api_client=self.api_client)

    def get_conn(self) -> client.ApiClient:
        """Build a new ``ApiClient`` whose configuration refreshes its API key via this hook."""
        configuration = self._get_config()
        configuration.refresh_api_key_hook = self._refresh_api_key_hook
        return client.ApiClient(configuration)

    def _refresh_api_key_hook(self, configuration: client.configuration.Configuration):
        """Replace ``configuration.api_key`` with a token obtained from the hook's credentials."""
        configuration.api_key = {"authorization": self._get_token(self.get_credentials())}

    def _get_config(self) -> client.configuration.Configuration:
        """Create the client configuration: cluster URL as host, Bearer token and CA certificate."""
        configuration = client.Configuration(
            host=self._cluster_url,
            api_key_prefix={"authorization": "Bearer"},
            api_key={"authorization": self._get_token(self.get_credentials())},
        )
        # ``ssl_ca_cert`` is set to whatever ``FileOrData(...).as_file()`` returns for the CA data.
        configuration.ssl_ca_cert = FileOrData(
            {
                "certificate-authority-data": self._ssl_ca_cert,
            },
            file_key_name="certificate-authority",
        ).as_file()
        return configuration

    @staticmethod
    def _get_token(creds: google.auth.credentials.Credentials) -> str:
        """Return ``creds.token``, refreshing the credentials first if the token is missing or expired."""
        if creds.token is None or creds.expired:
            auth_req = google_requests.Request()
            creds.refresh(auth_req)
        return creds.token

    def check_kueue_deployment_running(
        self,
        name: str,
        namespace: str,
        timeout: float | None = 300,
        polling_period_seconds: float = 2,
    ) -> None:
        """Block until the Deployment reports all replicas ready, or fail on timeout.

        The Deployment is considered ready when its status has a ``replicas`` value equal to
        ``ready_replicas`` and no ``unavailable_replicas``.

        :param name: Name of the Deployment to poll.
        :param namespace: Namespace of the Deployment.
        :param timeout: Maximum number of seconds to wait, measured on a monotonic clock so that
            time spent in API calls counts too. ``None`` waits indefinitely.
        :param polling_period_seconds: Seconds to sleep between two status checks.
        :raises AirflowException: If the Deployment is not ready before ``timeout`` elapses.
        """
        deadline = None if timeout is None else time.monotonic() + timeout

        while True:
            try:
                deployment = self.get_deployment_status(name=name, namespace=namespace)
                # A Deployment without a status yet is treated as "not ready" rather than an error.
                deployment_status = V1Deployment.to_dict(deployment).get("status") or {}
            except Exception:
                self.log.exception("Exception occurred while checking for Deployment status.")
                raise

            replicas = deployment_status.get("replicas")
            # ``replicas == ready_replicas`` with a non-None ``replicas`` already implies that
            # ``ready_replicas`` is not None.
            if (
                replicas is not None
                and deployment_status.get("unavailable_replicas") is None
                and replicas == deployment_status.get("ready_replicas")
            ):
                return

            if deadline is not None and time.monotonic() >= deadline:
                raise AirflowException("Deployment timed out")

            self.log.info("Waiting until Deployment will be ready...")
            time.sleep(polling_period_seconds)


class GKEAsyncHook(GoogleBaseAsyncHook):
"""Asynchronous client of GKE."""
Expand Down Expand Up @@ -430,6 +556,36 @@ def _get_token(creds: google.auth.credentials.Credentials) -> str:
creds.refresh(auth_req)
return creds.token

def apply_from_yaml_file(
    self,
    yaml_file: str | None = None,
    yaml_objects: list[dict] | None = None,
    verbose: bool = False,
    namespace: str = "default",
):
    """
    Create Kubernetes resources described in a YAML file or in already-loaded YAML objects.

    The call is delegated to ``utils.create_from_yaml`` using the client returned by
    ``self.get_conn()``.

    :param yaml_file: Contains the path to yaml file.
    :param yaml_objects: List of YAML objects; used instead of reading the yaml_file.
    :param verbose: If True, print confirmation from create action. Default is False.
    :param namespace: Contains the namespace to create all
        resources inside. The namespace must preexist otherwise the resource creation will fail. If the API
        object in the yaml file already contains a namespace definition this parameter has no effect.
    """
    creation_options = {
        "yaml_objects": yaml_objects,
        "yaml_file": yaml_file,
        "verbose": verbose,
        "namespace": namespace,
    }
    utils.create_from_yaml(k8s_client=self.get_conn(), **creation_options)

def get_pod(self, name: str, namespace: str) -> V1Pod:
"""Get a pod object.
Expand Down
Loading

0 comments on commit c23cbc5

Please sign in to comment.