diff --git a/airflow/providers/cncf/kubernetes/hooks/kubernetes.py b/airflow/providers/cncf/kubernetes/hooks/kubernetes.py
index f72a0bc2ab7dc..3a8faa98f6752 100644
--- a/airflow/providers/cncf/kubernetes/hooks/kubernetes.py
+++ b/airflow/providers/cncf/kubernetes/hooks/kubernetes.py
@@ -37,7 +37,7 @@ from airflow.utils import yaml
 
 if TYPE_CHECKING:
-    from kubernetes.client.models import V1Pod
+    from kubernetes.client.models import V1Deployment, V1Pod
 
 LOADING_KUBE_CONFIG_FILE_RESOURCE = "Loading Kubernetes configuration file kube_config from {}..."
@@ -282,6 +282,10 @@ def api_client(self) -> client.ApiClient:
     def core_v1_client(self) -> client.CoreV1Api:
         return client.CoreV1Api(api_client=self.api_client)
 
+    @cached_property
+    def apps_v1_client(self) -> client.AppsV1Api:
+        return client.AppsV1Api(api_client=self.api_client)
+
     @cached_property
     def custom_object_client(self) -> client.CustomObjectsApi:
         return client.CustomObjectsApi(api_client=self.api_client)
@@ -450,6 +454,24 @@ def get_namespaced_pod_list(
             **kwargs,
         )
 
+    def get_deployment_status(
+        self,
+        name: str,
+        namespace: str = "default",
+        **kwargs,
+    ) -> V1Deployment:
+        """Get status of existing Deployment.
+
+        :param name: Name of Deployment to retrieve
+        :param namespace: Deployment namespace
+        """
+        return self.apps_v1_client.read_namespaced_deployment_status(
+            name=name, namespace=namespace, pretty=True, **kwargs
+        )
+
 
 def _get_bool(val) -> bool | None:
     """Convert val to bool if can be done with certainty; if we cannot infer intention we return None."""
diff --git a/airflow/providers/google/cloud/hooks/kubernetes_engine.py b/airflow/providers/google/cloud/hooks/kubernetes_engine.py
index 4fb7534d878f8..2f5f64134022b 100644
--- a/airflow/providers/google/cloud/hooks/kubernetes_engine.py
+++ b/airflow/providers/google/cloud/hooks/kubernetes_engine.py
@@ -41,13 +41,15 @@
 from google.cloud import container_v1, exceptions  # type: ignore[attr-defined]
 from google.cloud.container_v1 import ClusterManagerAsyncClient, ClusterManagerClient
 from google.cloud.container_v1.types import Cluster, Operation
-from kubernetes import client
+from kubernetes import client, utils
+from kubernetes.client.models import V1Deployment
 from kubernetes_asyncio import client as async_client
 from kubernetes_asyncio.config.kube_config import FileOrData
 from urllib3.exceptions import HTTPError
 
 from airflow import version
 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
+from airflow.providers.cncf.kubernetes.hooks.kubernetes import KubernetesHook
 from airflow.providers.cncf.kubernetes.kube_client import _enable_tcp_keepalive
 from airflow.providers.cncf.kubernetes.utils.pod_manager import PodOperatorHookProtocol
 from airflow.providers.google.common.consts import CLIENT_INFO
@@ -298,6 +300,130 @@ def get_cluster(
             timeout=timeout,
         )
 
+    def check_cluster_autoscaling_ability(self, cluster: Cluster | dict):
+        """
+        Check if the specified Cluster has the ability to autoscale.
+
+        The Cluster should be Autopilot, have Node Auto-provisioning enabled, or have regular
+        auto-scaled node pools. Returns True if the Cluster supports autoscaling, otherwise
+        returns False.
+
+        :param cluster: The Cluster object.
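+
+        For example, an illustrative truncated cluster dict like
+        ``{"autopilot": {"enabled": False}, "autoscaling": {"enable_node_autoprovisioning": True},
+        "node_pools": []}`` qualifies through node auto-provisioning, while the same dict with
+        node auto-provisioning disabled and no autoscaled node pools does not.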
+ """ + if isinstance(cluster, Cluster): + cluster_dict_representation = Cluster.to_dict(cluster) + elif not isinstance(cluster, dict): + raise AirflowException("cluster is not instance of Cluster proto or python dict") + else: + cluster_dict_representation = cluster + + node_pools_autoscaled = False + for node_pool in cluster_dict_representation["node_pools"]: + try: + if node_pool["autoscaling"]["enabled"] is True: + node_pools_autoscaled = True + break + except KeyError: + self.log.info("No autoscaling enabled in Node pools level.") + break + if ( + cluster_dict_representation["autopilot"]["enabled"] + or cluster_dict_representation["autoscaling"]["enable_node_autoprovisioning"] + or node_pools_autoscaled + ): + return True + else: + return False + + +class GKEDeploymentHook(GoogleBaseHook, KubernetesHook): + """Google Kubernetes Engine Deployment APIs.""" + + def __init__( + self, + cluster_url: str, + ssl_ca_cert: str, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + self._cluster_url = cluster_url + self._ssl_ca_cert = ssl_ca_cert + + @cached_property + def api_client(self) -> client.ApiClient: + return self.get_conn() + + @cached_property + def core_v1_client(self) -> client.CoreV1Api: + return client.CoreV1Api(self.api_client) + + @cached_property + def batch_v1_client(self) -> client.BatchV1Api: + return client.BatchV1Api(self.api_client) + + @cached_property + def apps_v1_client(self) -> client.AppsV1Api: + return client.AppsV1Api(api_client=self.api_client) + + def get_conn(self) -> client.ApiClient: + configuration = self._get_config() + configuration.refresh_api_key_hook = self._refresh_api_key_hook + return client.ApiClient(configuration) + + def _refresh_api_key_hook(self, configuration: client.configuration.Configuration): + configuration.api_key = {"authorization": self._get_token(self.get_credentials())} + + def _get_config(self) -> client.configuration.Configuration: + configuration = client.Configuration( + host=self._cluster_url, + api_key_prefix={"authorization": "Bearer"}, + api_key={"authorization": self._get_token(self.get_credentials())}, + ) + configuration.ssl_ca_cert = FileOrData( + { + "certificate-authority-data": self._ssl_ca_cert, + }, + file_key_name="certificate-authority", + ).as_file() + return configuration + + @staticmethod + def _get_token(creds: google.auth.credentials.Credentials) -> str: + if creds.token is None or creds.expired: + auth_req = google_requests.Request() + creds.refresh(auth_req) + return creds.token + + def check_kueue_deployment_running(self, name, namespace): + timeout = 300 + polling_period_seconds = 2 + + while timeout is None or timeout > 0: + try: + deployment = self.get_deployment_status(name=name, namespace=namespace) + deployment_status = V1Deployment.to_dict(deployment)["status"] + replicas = deployment_status["replicas"] + ready_replicas = deployment_status["ready_replicas"] + unavailable_replicas = deployment_status["unavailable_replicas"] + if ( + replicas is not None + and ready_replicas is not None + and unavailable_replicas is None + and replicas == ready_replicas + ): + return + else: + self.log.info("Waiting until Deployment will be ready...") + time.sleep(polling_period_seconds) + except Exception as e: + self.log.exception("Exception occurred while checking for Deployment status.") + raise e + + if timeout is not None: + timeout -= polling_period_seconds + + raise AirflowException("Deployment timed out") + class GKEAsyncHook(GoogleBaseAsyncHook): """Asynchronous client of GKE.""" @@ -430,6 +556,34 
             creds.refresh(auth_req)
         return creds.token
 
+    def apply_from_yaml_file(
+        self,
+        yaml_file: str | None = None,
+        yaml_objects: list[dict] | None = None,
+        verbose: bool = False,
+        namespace: str = "default",
+    ):
+        """
+        Apply YAML manifests to the cluster, creating the resources they describe.
+
+        :param yaml_file: Path to a YAML file with one or more resource definitions.
+        :param yaml_objects: List of YAML objects; used instead of reading the yaml_file.
+        :param verbose: If True, print confirmation from create action. Default is False.
+        :param namespace: Contains the namespace to create all
+            resources inside. The namespace must preexist, otherwise the resource creation will fail.
+            If the API object in the yaml file already contains a namespace definition, this
+            parameter has no effect.
+        """
+        k8s_client = self.get_conn()
+
+        utils.create_from_yaml(
+            k8s_client=k8s_client,
+            yaml_objects=yaml_objects,
+            yaml_file=yaml_file,
+            verbose=verbose,
+            namespace=namespace,
+        )
+
     def get_pod(self, name: str, namespace: str) -> V1Pod:
         """Get a pod object.
diff --git a/airflow/providers/google/cloud/operators/kubernetes_engine.py b/airflow/providers/google/cloud/operators/kubernetes_engine.py
index 2d2bf7337d62f..61e366ee54884 100644
--- a/airflow/providers/google/cloud/operators/kubernetes_engine.py
+++ b/airflow/providers/google/cloud/operators/kubernetes_engine.py
@@ -18,18 +18,22 @@
 """This module contains Google Kubernetes Engine operators."""
 
 from __future__ import annotations
 
+import re
 import warnings
 from functools import cached_property
 from typing import TYPE_CHECKING, Any, Sequence
 
+import requests
+import yaml
 from google.api_core.exceptions import AlreadyExists
 from google.cloud.container_v1.types import Cluster
+from kubernetes.utils.create_from_yaml import FailToCreateError
 
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
 from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator
 from airflow.providers.cncf.kubernetes.utils.pod_manager import OnFinishAction
-from airflow.providers.google.cloud.hooks.kubernetes_engine import GKEHook, GKEPodHook
+from airflow.providers.google.cloud.hooks.kubernetes_engine import GKEDeploymentHook, GKEHook, GKEPodHook
 from airflow.providers.google.cloud.links.kubernetes_engine import (
     KubernetesEngineClusterLink,
     KubernetesEnginePodLink,
@@ -46,6 +50,45 @@
 KUBE_CONFIG_ENV_VAR = "KUBECONFIG"
 
 
+class GKEClusterAuthDetails:
+    """
+    Helper for fetching information about a cluster, for connecting to it.
+
+    :param cluster_name: The name of the Google Kubernetes Engine cluster the pod should be spawned in.
+    :param project_id: The Google Developers Console project id.
+    :param use_internal_ip: Use the internal IP address as the endpoint.
+    :param cluster_hook: Airflow hook for working with the Kubernetes cluster.
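+
+    A sketch of the intended usage, mirroring ``GKEStartKueueInsideClusterOperator.execute``
+    (here ``gke_hook`` stands in for any configured ``GKEHook`` instance)::
+
+        auth_details = GKEClusterAuthDetails(
+            cluster_name="my-cluster",
+            project_id="my-project",
+            use_internal_ip=False,
+            cluster_hook=gke_hook,
+        )
+        cluster_url, ssl_ca_cert = auth_details.fetch_cluster_info()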
+ """ + + def __init__( + self, + cluster_name, + project_id, + use_internal_ip, + cluster_hook, + ): + self.cluster_name = cluster_name + self.project_id = project_id + self.use_internal_ip = use_internal_ip + self.cluster_hook = cluster_hook + self._cluster_url = None + self._ssl_ca_cert = None + + def fetch_cluster_info(self) -> tuple[str, str | None]: + """Fetches cluster info for connecting to it.""" + cluster = self.cluster_hook.get_cluster( + name=self.cluster_name, + project_id=self.project_id, + ) + + if not self.use_internal_ip: + self._cluster_url = f"https://{cluster.endpoint}" + else: + self._cluster_url = f"https://{cluster.private_cluster_config.private_endpoint}" + self._ssl_ca_cert = cluster.master_auth.cluster_ca_certificate + return self._cluster_url, self._ssl_ca_cert + + class GKEDeleteClusterOperator(GoogleCloudBaseOperator): """ Deletes the cluster, including the Kubernetes endpoint and all worker nodes. @@ -387,6 +430,153 @@ def _get_hook(self) -> GKEHook: return self._hook +class GKEStartKueueInsideClusterOperator(GoogleCloudBaseOperator): + """ + Installs Kueue of specific version inside Cluster. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:GKEStartKueueInsideClusterOperator` + + .. seealso:: + For more details about Kueue have a look at the reference: + https://kueue.sigs.k8s.io/docs/overview/ + + :param project_id: The Google Developers Console [project ID or project number]. + :param location: The name of the Google Kubernetes Engine zone or region in which the cluster resides. + :param cluster_name: The Cluster name in which to install Kueue. + :param kueue_version: Version of Kueue to install. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ """ + + template_fields: Sequence[str] = ( + "project_id", + "location", + "kueue_version", + "cluster_name", + "gcp_conn_id", + "impersonation_chain", + ) + operator_extra_links = (KubernetesEngineClusterLink(),) + + def __init__( + self, + *, + location: str, + cluster_name: str, + kueue_version: str, + use_internal_ip: bool = False, + project_id: str | None = None, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.project_id = project_id + self.location = location + self.cluster_name = cluster_name + self.kueue_version = kueue_version + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + self.use_internal_ip = use_internal_ip + self._kueue_yaml_url = ( + f"https://github.com/kubernetes-sigs/kueue/releases/download/{self.kueue_version}/manifests.yaml" + ) + + @cached_property + def cluster_hook(self) -> GKEHook: + return GKEHook( + gcp_conn_id=self.gcp_conn_id, + location=self.location, + impersonation_chain=self.impersonation_chain, + ) + + @cached_property + def deployment_hook(self) -> GKEDeploymentHook: + if self._cluster_url is None or self._ssl_ca_cert is None: + raise AttributeError( + "Cluster url and ssl_ca_cert should be defined before using self.hook method. " + "Try to use self.get_kube_creds method", + ) + return GKEDeploymentHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + cluster_url=self._cluster_url, + ssl_ca_cert=self._ssl_ca_cert, + ) + + @cached_property + def pod_hook(self) -> GKEPodHook: + if self._cluster_url is None or self._ssl_ca_cert is None: + raise AttributeError( + "Cluster url and ssl_ca_cert should be defined before using self.hook method. 
" + "Try to use self.get_kube_creds method", + ) + return GKEPodHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + cluster_url=self._cluster_url, + ssl_ca_cert=self._ssl_ca_cert, + ) + + @staticmethod + def _get_yaml_content_from_file(kueue_yaml_url) -> list[dict]: + """Helper method to download content of YAML file and separate it into several dictionaries.""" + response = requests.get(kueue_yaml_url, allow_redirects=True) + yaml_dicts = [] + if response.status_code == 200: + yaml_data = response.text + documents = re.split(r"---\n", yaml_data) + + for document in documents: + document_dict = yaml.safe_load(document) + yaml_dicts.append(document_dict) + else: + raise AirflowException("Was not able to read the yaml file from given URL") + return yaml_dicts + + def execute(self, context: Context): + cluster_info = GKEClusterAuthDetails( + project_id=self.project_id, + cluster_name=self.cluster_name, + use_internal_ip=self.use_internal_ip, + cluster_hook=self.cluster_hook, + ) + self._cluster_url, self._ssl_ca_cert = cluster_info.fetch_cluster_info() + + cluster = self.cluster_hook.get_cluster( + name=self.cluster_name, + project_id=self.project_id, + ) + KubernetesEngineClusterLink.persist(context=context, task_instance=self, cluster=cluster) + + yaml_objects = self._get_yaml_content_from_file(kueue_yaml_url=self._kueue_yaml_url) + + if self.cluster_hook.check_cluster_autoscaling_ability(cluster=cluster): + try: + self.pod_hook.apply_from_yaml_file(yaml_objects=yaml_objects) + + self.deployment_hook.check_kueue_deployment_running( + name="kueue-controller-manager", namespace="kueue-system" + ) + + self.log.info("Kueue installed successfully!") + except FailToCreateError: + self.log.info("Kueue is already enabled for the cluster") + else: + self.log.info( + "Cluster doesn't have ability to autoscale, will not install Kueue inside. Aborting" + ) + + class GKEStartPodOperator(KubernetesPodOperator): """ Executes a task in a Kubernetes pod in the specified Google Kubernetes Engine cluster. 
@@ -495,8 +685,14 @@ def __init__(
         self.use_internal_ip = use_internal_ip
         self.pod: V1Pod | None = None
-        self._ssl_ca_cert: str | None = None
-        self._cluster_url: str | None = None
+
+        cluster_info = GKEClusterAuthDetails(
+            project_id=self.project_id,
+            cluster_name=self.cluster_name,
+            use_internal_ip=self.use_internal_ip,
+            cluster_hook=self.cluster_hook,
+        )
+        self._cluster_url, self._ssl_ca_cert = cluster_info.fetch_cluster_info()
 
         if self.gcp_conn_id is None:
             raise AirflowException(
@@ -544,23 +740,8 @@ def hook(self) -> GKEPodHook:
 
     def execute(self, context: Context):
         """Executes process of creating pod and executing provided command inside it."""
-        self.fetch_cluster_info()
         return super().execute(context)
 
-    def fetch_cluster_info(self) -> tuple[str, str | None]:
-        """Fetches cluster info for connecting to it."""
-        cluster = self.cluster_hook.get_cluster(
-            name=self.cluster_name,
-            project_id=self.project_id,
-        )
-
-        if not self.use_internal_ip:
-            self._cluster_url = f"https://{cluster.endpoint}"
-        else:
-            self._cluster_url = f"https://{cluster.private_cluster_config.private_endpoint}"
-        self._ssl_ca_cert = cluster.master_auth.cluster_ca_certificate
-        return self._cluster_url, self._ssl_ca_cert
-
     def invoke_defer_method(self):
         """Method to easily redefine triggers which are being used in child classes."""
         trigger_start_time = utcnow()
diff --git a/docs/apache-airflow-providers-google/operators/cloud/kubernetes_engine.rst b/docs/apache-airflow-providers-google/operators/cloud/kubernetes_engine.rst
index 74c396c9855d0..7663ce48118ff 100644
--- a/docs/apache-airflow-providers-google/operators/cloud/kubernetes_engine.rst
+++ b/docs/apache-airflow-providers-google/operators/cloud/kubernetes_engine.rst
@@ -71,6 +71,25 @@ lot less resources wasted on idle Operators or Sensors:
     :end-before: [END howto_operator_gke_create_cluster_async]
 
 
+.. _howto/operator:GKEStartKueueInsideClusterOperator:
+
+Install a specific version of Kueue inside the cluster
+"""""""""""""""""""""""""""""""""""""""""""""""""""""""
+
+Kueue is a Cloud Native Job scheduler that works with the default Kubernetes scheduler, the Job controller,
+and the cluster autoscaler to provide an end-to-end batch system. Kueue implements Job queueing, deciding when
+Jobs should wait and when they should start, based on quotas and a hierarchy for sharing resources fairly among teams.
+Kueue supports Autopilot clusters, Standard GKE with Node Auto-provisioning, and regular autoscaled node pools.
+To install and use Kueue on your cluster, use
+:class:`~airflow.providers.google.cloud.operators.kubernetes_engine.GKEStartKueueInsideClusterOperator`,
+as shown in this example:
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/kubernetes_engine/example_kubernetes_engine_kueue.py
+    :language: python
+    :start-after: [START howto_operator_gke_install_kueue]
+    :end-before: [END howto_operator_gke_install_kueue]
+
+
 .. _howto/operator:GKEDeleteClusterOperator:
 
 Delete GKE cluster
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index 1dc8d5693b398..4eccdc614ff4f 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -113,9 +113,11 @@ autoenv
 autogenerated
 automl
 AutoMlClient
+autoprovisioning
 autorestart
 Autoscale
 autoscale
+autoscaled
 autoscaler
 autoscaling
 avp
@@ -904,6 +906,7 @@ kubeconfig
 Kubernetes
 kubernetes
 KubernetesPodOperator
+Kueue
 Kusto
 kv
 kwarg
@@ -1279,6 +1282,7 @@ QuboleCheckHook
 Quboles
 queryParameters
 querystring
+queueing
 quickstart
 quotechar
 rabbitmq
diff --git a/tests/providers/google/cloud/hooks/test_kubernetes_engine.py b/tests/providers/google/cloud/hooks/test_kubernetes_engine.py
index 954f3e018ebff..61abe8e8f84ea 100644
--- a/tests/providers/google/cloud/hooks/test_kubernetes_engine.py
+++ b/tests/providers/google/cloud/hooks/test_kubernetes_engine.py
@@ -24,10 +24,12 @@
 import pytest
 from google.cloud.container_v1 import ClusterManagerAsyncClient
 from google.cloud.container_v1.types import Cluster
+from kubernetes.client.models import V1Deployment, V1DeploymentStatus
 
 from airflow.exceptions import AirflowException
 from airflow.providers.google.cloud.hooks.kubernetes_engine import (
     GKEAsyncHook,
+    GKEDeploymentHook,
     GKEHook,
     GKEPodAsyncHook,
     GKEPodHook,
@@ -37,10 +39,12 @@
 
 TASK_ID = "test-gke-cluster-operator"
 CLUSTER_NAME = "test-cluster"
+NAMESPACE = "test-cluster-namespace"
 TEST_GCP_PROJECT_ID = "test-project"
 GKE_ZONE = "test-zone"
 BASE_STRING = "airflow.providers.google.common.hooks.base_google.{}"
 GKE_STRING = "airflow.providers.google.cloud.hooks.kubernetes_engine.{}"
+K8S_HOOK = "airflow.providers.cncf.kubernetes.hooks.kubernetes.KubernetesHook"
 CLUSTER_URL = "https://path.to.cluster"
 SSL_CA_CERT = "test-ssl-ca-cert"
 POD_NAME = "test-pod-name"
@@ -49,6 +53,113 @@
 GCP_CONN_ID = "test-gcp-conn-id"
 IMPERSONATE_CHAIN = ["impersonate", "this", "test"]
 OPERATION_NAME = "test-operation-name"
+CLUSTER_TEST_AUTOPILOT = {
+    "name": "autopilot-cluster",
+    "initial_node_count": 1,
+    "autopilot": {
+        "enabled": True,
+    },
+    "autoscaling": {
+        "enable_node_autoprovisioning": False,
+    },
+    "node_pools": [
+        {
+            "name": "pool",
+            "config": {"machine_type": "e2-standard-32", "disk_size_gb": 11},
+            "initial_node_count": 2,
+        }
+    ],
+}
+CLUSTER_TEST_AUTOPROVISIONING = {
+    "name": "cluster_autoprovisioning",
+    "initial_node_count": 1,
+    "autopilot": {
+        "enabled": False,
+    },
+    "node_pools": [
+        {
+            "name": "pool",
+            "config": {"machine_type": "e2-standard-32", "disk_size_gb": 11},
+            "initial_node_count": 2,
+        }
+    ],
+    "autoscaling": {
+        "enable_node_autoprovisioning": True,
+        "resource_limits": [
+            {"resource_type": "cpu", "maximum": 1000000000},
+            {"resource_type": "memory", "maximum": 1000000000},
+        ],
+    },
+}
+CLUSTER_TEST_AUTOSCALED = {
+    "name": "autoscaled_cluster",
+    "autopilot": {
+        "enabled": False,
+    },
+    "node_pools": [
+        {
+            "name": "autoscaled-pool",
+            "config": {"machine_type": "e2-standard-32", "disk_size_gb": 11},
+            "initial_node_count": 2,
+            "autoscaling": {
+                "enabled": True,
+                "max_node_count": 10,
+            },
+        }
+    ],
+    "autoscaling": {
+        "enable_node_autoprovisioning": False,
+    },
+}
+
+CLUSTER_TEST_REGULAR = {
+    "name": "regular_cluster",
+    "initial_node_count": 1,
+    "autopilot": {
+        "enabled": False,
+    },
+    "node_pools": [
+        {
+            "name": "autoscaled-pool",
+            "config": {"machine_type": "e2-standard-32", "disk_size_gb": 11},
+            "initial_node_count": 2,
+            "autoscaling": {
+                "enabled": False,
+            },
+        }
+    ],
+    "autoscaling": {
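+        # Node auto-provisioning is disabled as well, so this fixture should make
+        # check_cluster_autoscaling_ability return False.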
"enable_node_autoprovisioning": False, + }, +} +pods = { + "succeeded": { + "metadata": {"name": "test-pod", "namespace": "default"}, + "status": {"phase": "Succeeded"}, + }, + "pending": { + "metadata": {"name": "test-pod", "namespace": "default"}, + "status": {"phase": "Pending"}, + }, + "running": { + "metadata": {"name": "test-pod", "namespace": "default"}, + "status": {"phase": "Running"}, + }, +} +NOT_READY_DEPLOYMENT = V1Deployment( + status=V1DeploymentStatus( + observed_generation=1, + ready_replicas=None, + replicas=None, + unavailable_replicas=1, + updated_replicas=None, + ) +) +READY_DEPLOYMENT = V1Deployment( + status=V1DeploymentStatus( + observed_generation=1, ready_replicas=1, replicas=1, unavailable_replicas=None, updated_replicas=1 + ) +) @pytest.mark.db_test @@ -298,6 +409,82 @@ def test_wait_for_response_running(self, time_mock, operation_mock): operation_mock.assert_any_call(pending_op.name, project_id=TEST_GCP_PROJECT_ID) assert operation_mock.call_count == 2 + @pytest.mark.parametrize( + "cluster_obj, expected_result", + [ + (CLUSTER_TEST_AUTOPROVISIONING, True), + (CLUSTER_TEST_AUTOSCALED, True), + (CLUSTER_TEST_AUTOPILOT, True), + (CLUSTER_TEST_REGULAR, False), + ], + ) + def test_check_cluster_autoscaling_ability(self, cluster_obj, expected_result): + result = self.gke_hook.check_cluster_autoscaling_ability(cluster_obj) + assert result == expected_result + + +class TestGKEDeploymentHook: + def setup_method(self): + with mock.patch( + BASE_STRING.format("GoogleBaseHook.__init__"), new=mock_base_gcp_hook_default_project_id + ): + self.gke_hook = GKEDeploymentHook(gcp_conn_id="test", ssl_ca_cert=None, cluster_url=None) + self.gke_hook._client = mock.Mock() + + def refresh_token(request): + self.credentials.token = "New" + + self.credentials = mock.MagicMock() + self.credentials.token = "Old" + self.credentials.expired = False + self.credentials.refresh = refresh_token + + @mock.patch(GKE_STRING.format("google_requests.Request")) + def test_get_connection_update_hook_with_invalid_token(self, mock_request): + self.gke_hook._get_config = self._get_config + self.gke_hook.get_credentials = self._get_credentials + self.gke_hook.get_credentials().expired = True + the_client: kubernetes.client.ApiClient = self.gke_hook.get_conn() + + the_client.configuration.refresh_api_key_hook(the_client.configuration) + + assert self.gke_hook.get_credentials().token == "New" + + @mock.patch(GKE_STRING.format("google_requests.Request")) + def test_get_connection_update_hook_with_valid_token(self, mock_request): + self.gke_hook._get_config = self._get_config + self.gke_hook.get_credentials = self._get_credentials + self.gke_hook.get_credentials().expired = False + the_client: kubernetes.client.ApiClient = self.gke_hook.get_conn() + + the_client.configuration.refresh_api_key_hook(the_client.configuration) + + assert self.gke_hook.get_credentials().token == "Old" + + def _get_config(self): + return kubernetes.client.configuration.Configuration() + + def _get_credentials(self): + return self.credentials + + @mock.patch("kubernetes.client.AppsV1Api") + def test_check_kueue_deployment_running(self, gke_deployment_hook, caplog): + gke_deployment_hook.return_value.read_namespaced_deployment_status.side_effect = [ + NOT_READY_DEPLOYMENT, + READY_DEPLOYMENT, + ] + self.gke_hook.check_kueue_deployment_running(name=CLUSTER_NAME, namespace=NAMESPACE) + + assert "Waiting until Deployment will be ready..." 
+
+    @mock.patch("kubernetes.client.AppsV1Api")
+    def test_check_kueue_deployment_raise_exception(self, gke_deployment_hook, caplog):
+        gke_deployment_hook.return_value.read_namespaced_deployment_status.side_effect = ValueError()
+        with pytest.raises(ValueError):
+            self.gke_hook.check_kueue_deployment_running(name=CLUSTER_NAME, namespace=NAMESPACE)
+
+        assert "Exception occurred while checking for Deployment status." in caplog.text
+
 
 class TestGKEPodAsyncHook:
     @staticmethod
diff --git a/tests/providers/google/cloud/operators/test_kubernetes_engine.py b/tests/providers/google/cloud/operators/test_kubernetes_engine.py
index 7805804485682..6fee6dda33ada 100644
--- a/tests/providers/google/cloud/operators/test_kubernetes_engine.py
+++ b/tests/providers/google/cloud/operators/test_kubernetes_engine.py
@@ -22,14 +22,18 @@
 from unittest import mock
 
 import pytest
+from kubernetes.client.models import V1Deployment, V1DeploymentStatus
+from kubernetes.utils.create_from_yaml import FailToCreateError
 
 from airflow.exceptions import AirflowException, TaskDeferred
 from airflow.models import Connection
 from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator
 from airflow.providers.cncf.kubernetes.utils.pod_manager import OnFinishAction
 from airflow.providers.google.cloud.operators.kubernetes_engine import (
+    GKEClusterAuthDetails,
     GKECreateClusterOperator,
     GKEDeleteClusterOperator,
+    GKEStartKueueInsideClusterOperator,
     GKEStartPodOperator,
 )
 from airflow.providers.google.cloud.triggers.kubernetes_engine import GKEStartPodTrigger
@@ -61,14 +65,27 @@
 KUBE_ENV_VAR = "KUBECONFIG"
 FILE_NAME = "/tmp/mock_name"
 KUB_OP_PATH = "airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator.{}"
-GKE_HOOK_MODULE_PATH = "airflow.providers.google.cloud.operators.kubernetes_engine"
+GKE_HOOK_MODULE_PATH = "airflow.providers.google.cloud.hooks.kubernetes_engine"
 GKE_HOOK_PATH = f"{GKE_HOOK_MODULE_PATH}.GKEHook"
+GKE_POD_HOOK_PATH = f"{GKE_HOOK_MODULE_PATH}.GKEPodHook"
+GKE_DEPLOYMENT_HOOK_PATH = f"{GKE_HOOK_MODULE_PATH}.GKEDeploymentHook"
 KUB_OPERATOR_EXEC = "airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator.execute"
 TEMP_FILE = "tempfile.NamedTemporaryFile"
 GKE_OP_PATH = "airflow.providers.google.cloud.operators.kubernetes_engine.GKEStartPodOperator"
+GKE_CLUSTER_AUTH_DETAILS_PATH = (
+    "airflow.providers.google.cloud.operators.kubernetes_engine.GKEClusterAuthDetails"
+)
 CLUSTER_URL = "https://test-host"
 CLUSTER_PRIVATE_URL = "https://test-private-host"
 SSL_CA_CERT = "TEST_SSL_CA_CERT_CONTENT"
+KUEUE_VERSION = "v0.5.1"
+IMPERSONATION_CHAIN = "sa-@google.com"
+USE_INTERNAL_API = False
+READY_DEPLOYMENT = V1Deployment(
+    status=V1DeploymentStatus(
+        observed_generation=1, ready_replicas=1, replicas=1, unavailable_replicas=None, updated_replicas=1
+    )
+)
 
 
 class TestGoogleCloudPlatformContainerOperator:
@@ -306,7 +323,7 @@ def test_cluster_info(self, get_cluster_mock, use_internal_ip):
                 "master_auth.cluster_ca_certificate": SSL_CA_CERT,
             }
         )
-        gke_op = GKEStartPodOperator(
+        GKEStartPodOperator(
            project_id=TEST_GCP_PROJECT_ID,
            location=PROJECT_LOCATION,
            cluster_name=CLUSTER_NAME,
@@ -316,13 +333,27 @@
             image=IMAGE,
             use_internal_ip=use_internal_ip,
         )
-        cluster_url, ssl_ca_cert = gke_op.fetch_cluster_info()
+        cluster_info = GKEClusterAuthDetails(
+            project_id=TEST_GCP_PROJECT_ID,
+            cluster_name=CLUSTER_NAME,
+            use_internal_ip=use_internal_ip,
+            cluster_hook=get_cluster_mock,
+        )
+        cluster_url, ssl_ca_cert = cluster_info.fetch_cluster_info()
 
         assert cluster_url == CLUSTER_PRIVATE_URL if use_internal_ip else CLUSTER_URL
         assert ssl_ca_cert == SSL_CA_CERT
 
     @pytest.mark.db_test
-    def test_default_gcp_conn_id(self):
+    @mock.patch(f"{GKE_HOOK_PATH}.get_cluster")
+    def test_default_gcp_conn_id(self, get_cluster_mock):
+        get_cluster_mock.return_value = mock.MagicMock(
+            **{
+                "endpoint": "test-host",
+                "private_cluster_config.private_endpoint": "test-private-host",
+                "master_auth.cluster_ca_certificate": SSL_CA_CERT,
+            }
+        )
         gke_op = GKEStartPodOperator(
             project_id=TEST_GCP_PROJECT_ID,
             location=PROJECT_LOCATION,
@@ -420,6 +451,135 @@ def test_on_finish_action_handler(
         assert op.__getattribute__(expected_attr) == expected_attributes[expected_attr]
 
 
+class TestGKEStartKueueInsideClusterOperator:
+    def setup_method(self):
+        self.gke_op = GKEStartKueueInsideClusterOperator(
+            project_id=TEST_GCP_PROJECT_ID,
+            location=PROJECT_LOCATION,
+            cluster_name=CLUSTER_NAME,
+            task_id=PROJECT_TASK_ID,
+            kueue_version=KUEUE_VERSION,
+            impersonation_chain=IMPERSONATION_CHAIN,
+            use_internal_ip=USE_INTERNAL_API,
+        )
+        self.gke_op._cluster_url = CLUSTER_URL
+        self.gke_op._ssl_ca_cert = SSL_CA_CERT
+
+    @mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info")
+    @mock.patch(GKE_HOOK_PATH)
+    @mock.patch(GKE_DEPLOYMENT_HOOK_PATH)
+    def test_execute(self, mock_depl_hook, mock_hook, fetch_cluster_info_mock):
+        fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT)
+        mock_hook.return_value.check_cluster_autoscaling_ability.return_value = False
+        mock_depl_hook.return_value.get_deployment_status.return_value = READY_DEPLOYMENT
+        self.gke_op.execute(context=mock.MagicMock())
+        mock_hook.return_value.get_cluster.assert_called_once()
+
+    @mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info")
+    @mock.patch(GKE_HOOK_PATH)
+    @mock.patch(GKE_POD_HOOK_PATH)
+    @mock.patch(GKE_DEPLOYMENT_HOOK_PATH)
+    def test_execute_autoscaled_cluster(
+        self, mock_depl_hook, mock_pod_hook, mock_hook, fetch_cluster_info_mock, caplog
+    ):
+        fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT)
+        mock_hook.return_value.check_cluster_autoscaling_ability.return_value = True
+        mock_pod_hook.return_value.apply_from_yaml_file.side_effect = mock.MagicMock()
+        mock_depl_hook.return_value.get_deployment_status.return_value = READY_DEPLOYMENT
+        self.gke_op.execute(context=mock.MagicMock())
+        assert "Kueue installed successfully!" in caplog.text
+
+    @mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info")
+    @mock.patch(GKE_HOOK_PATH)
+    @mock.patch(GKE_POD_HOOK_PATH)
+    def test_execute_autoscaled_cluster_check_error(
+        self, mock_pod_hook, mock_hook, fetch_cluster_info_mock, caplog
+    ):
+        fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT)
+        mock_hook.return_value.check_cluster_autoscaling_ability.return_value = True
+        mock_pod_hook.return_value.apply_from_yaml_file.side_effect = FailToCreateError("error")
+        self.gke_op.execute(context=mock.MagicMock())
+        assert "Kueue is already enabled for the cluster" in caplog.text
+
+    @mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info")
+    @mock.patch(GKE_HOOK_PATH)
+    @mock.patch(GKE_POD_HOOK_PATH)
+    def test_execute_non_autoscaled_cluster_check_error(
+        self, mock_pod_hook, mock_hook, fetch_cluster_info_mock, caplog
+    ):
+        fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT)
+        mock_hook.return_value.check_cluster_autoscaling_ability.return_value = False
+        self.gke_op.execute(context=mock.MagicMock())
+        assert (
+            "Cluster doesn't have ability to autoscale, will not install Kueue inside. Aborting"
+            in caplog.text
+        )
+
+    @mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info")
+    @mock.patch(GKE_HOOK_PATH)
+    def test_execute_with_impersonation_service_account(self, mock_hook, fetch_cluster_info_mock):
+        fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT)
+        self.gke_op.impersonation_chain = "test_account@example.com"
+        mock_hook.return_value.check_cluster_autoscaling_ability.return_value = False
+        self.gke_op.execute(context=mock.MagicMock())
+        mock_hook.return_value.get_cluster.assert_called_once()
+
+    @mock.patch.dict(os.environ, {})
+    @mock.patch(
+        "airflow.hooks.base.BaseHook.get_connections",
+        return_value=[Connection(extra=json.dumps({"keyfile_dict": '{"private_key": "r4nd0m_k3y"}'}))],
+    )
+    @mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info")
+    @mock.patch(GKE_HOOK_PATH)
+    def test_execute_with_impersonation_service_chain_one_element(
+        self, mock_hook, fetch_cluster_info_mock, get_con_mock
+    ):
+        mock_hook.return_value.check_cluster_autoscaling_ability.return_value = False
+        fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT)
+        self.gke_op.impersonation_chain = ["test_account@example.com"]
+        self.gke_op.execute(context=mock.MagicMock())
+
+        mock_hook.return_value.get_cluster.assert_called_once()
+
+    @pytest.mark.db_test
+    def test_default_gcp_conn_id(self):
+        gke_op = GKEStartKueueInsideClusterOperator(
+            project_id=TEST_GCP_PROJECT_ID,
+            location=PROJECT_LOCATION,
+            cluster_name=CLUSTER_NAME,
+            task_id=PROJECT_TASK_ID,
+            kueue_version=KUEUE_VERSION,
+            impersonation_chain=IMPERSONATION_CHAIN,
+            use_internal_ip=USE_INTERNAL_API,
+        )
+        gke_op._cluster_url = CLUSTER_URL
+        gke_op._ssl_ca_cert = SSL_CA_CERT
+        hook = gke_op.cluster_hook
+
+        assert hook.gcp_conn_id == "google_cloud_default"
+
+    @mock.patch(
+        "airflow.providers.google.common.hooks.base_google.GoogleBaseHook.get_connection",
+        return_value=Connection(conn_id="test_conn"),
+    )
+    def test_gcp_conn_id(self, get_con_mock):
+        gke_op = GKEStartKueueInsideClusterOperator(
+            project_id=TEST_GCP_PROJECT_ID,
+            location=PROJECT_LOCATION,
+            cluster_name=CLUSTER_NAME,
+            task_id=PROJECT_TASK_ID,
+            kueue_version=KUEUE_VERSION,
+            impersonation_chain=IMPERSONATION_CHAIN,
+            use_internal_ip=USE_INTERNAL_API,
+            gcp_conn_id="test_conn",
+        )
+        gke_op._cluster_url = CLUSTER_URL
+        gke_op._ssl_ca_cert = SSL_CA_CERT
+        hook = gke_op.cluster_hook
+
+        assert hook.gcp_conn_id == "test_conn"
+
 
 class TestGKEPodOperatorAsync:
     def setup_method(self):
         self.gke_op = GKEStartPodOperator(
diff --git a/tests/system/providers/google/cloud/kubernetes_engine/example_kubernetes_engine_kueue.py b/tests/system/providers/google/cloud/kubernetes_engine/example_kubernetes_engine_kueue.py
new file mode 100644
index 0000000000000..d87d0bd7d7470
--- /dev/null
+++ b/tests/system/providers/google/cloud/kubernetes_engine/example_kubernetes_engine_kueue.py
@@ -0,0 +1,84 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Example Airflow DAG for Google Kubernetes Engine.
+"""
+from __future__ import annotations
+
+import os
+from datetime import datetime
+
+from airflow.models.dag import DAG
+from airflow.providers.google.cloud.operators.kubernetes_engine import (
+    GKECreateClusterOperator,
+    GKEDeleteClusterOperator,
+    GKEStartKueueInsideClusterOperator,
+)
+
+ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID")
+DAG_ID = "example_kubernetes_engine_kueue"
+GCP_PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "default")
+
+GCP_LOCATION = "europe-west3"
+CLUSTER_NAME = f"cluster-name-test-kueue-{ENV_ID}".replace("_", "-")
+CLUSTER = {"name": CLUSTER_NAME, "initial_node_count": 1, "autopilot": {"enabled": True}}
+
+with DAG(
+    DAG_ID,
+    schedule="@once",  # Override to match your needs
+    start_date=datetime(2021, 1, 1),
+    catchup=False,
+    tags=["example", "kubernetes-engine", "kueue"],
+) as dag:
+    create_cluster = GKECreateClusterOperator(
+        task_id="create_cluster",
+        project_id=GCP_PROJECT_ID,
+        location=GCP_LOCATION,
+        body=CLUSTER,
+    )
+
+    # [START howto_operator_gke_install_kueue]
+    add_kueue_cluster = GKEStartKueueInsideClusterOperator(
+        task_id="add_kueue_cluster",
+        project_id=GCP_PROJECT_ID,
+        location=GCP_LOCATION,
+        cluster_name=CLUSTER_NAME,
+        kueue_version="v0.5.1",
+    )
+    # [END howto_operator_gke_install_kueue]
+
+    delete_cluster = GKEDeleteClusterOperator(
+        task_id="delete_cluster",
+        name=CLUSTER_NAME,
+        project_id=GCP_PROJECT_ID,
+        location=GCP_LOCATION,
+    )
+
+    create_cluster >> add_kueue_cluster >> delete_cluster
+
+    from tests.system.utils.watcher import watcher
+
+    # This test needs watcher in order to properly mark success/failure
+    # when "teardown" task with trigger rule is part of the DAG
+    list(dag.tasks) >> watcher()
+
+
+from tests.system.utils import get_test_run  # noqa: E402
+
+# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)