diff --git a/providers/src/airflow/providers/cncf/kubernetes/hooks/kubernetes.py b/providers/src/airflow/providers/cncf/kubernetes/hooks/kubernetes.py
index 11c8bd71c9d4f..6a377be3eb27c 100644
--- a/providers/src/airflow/providers/cncf/kubernetes/hooks/kubernetes.py
+++ b/providers/src/airflow/providers/cncf/kubernetes/hooks/kubernetes.py
@@ -26,9 +26,11 @@ from typing import TYPE_CHECKING, Any
 
 import aiofiles
+import requests
 import tenacity
 from asgiref.sync import sync_to_async
-from kubernetes import client, config, watch
+from kubernetes import client, config, utils, watch
+from kubernetes.client.models import V1Deployment
 from kubernetes.config import ConfigException
 from kubernetes_asyncio import client as async_client, config as async_config
 from urllib3.exceptions import HTTPError
@@ -47,7 +49,7 @@ if TYPE_CHECKING:
     from kubernetes.client import V1JobList
-    from kubernetes.client.models import V1Deployment, V1Job, V1Pod
+    from kubernetes.client.models import V1Job, V1Pod
 
 LOADING_KUBE_CONFIG_FILE_RESOURCE = "Loading Kubernetes configuration file kube_config from {}..."
@@ -489,12 +491,9 @@ def get_deployment_status(
         :param name: Name of Deployment to retrieve
         :param namespace: Deployment namespace
         """
-        try:
-            return self.apps_v1_client.read_namespaced_deployment_status(
-                name=name, namespace=namespace, pretty=True, **kwargs
-            )
-        except Exception as exc:
-            raise exc
+        return self.apps_v1_client.read_namespaced_deployment_status(
+            name=name, namespace=namespace, pretty=True, **kwargs
+        )
 
     @tenacity.retry(
         stop=tenacity.stop_after_attempt(3),
@@ -644,6 +643,71 @@ def patch_namespaced_job(self, job_name: str, namespace: str, body: object) -> V
             body=body,
         )
 
+    def apply_from_yaml_file(
+        self,
+        api_client: Any = None,
+        yaml_file: str | None = None,
+        yaml_objects: list[dict] | None = None,
+        verbose: bool = False,
+        namespace: str = "default",
+    ):
+        """
+        Create Kubernetes resources from a YAML file or a list of YAML objects.
+
+        :param api_client: Optional Kubernetes API client; defaults to the hook's client.
+        :param yaml_file: Path to the YAML file.
+        :param yaml_objects: List of YAML objects; used instead of reading the yaml_file.
+        :param verbose: If True, print confirmation from create action. Default is False.
+        :param namespace: The namespace to create all resources in. The namespace must
+            already exist, otherwise resource creation will fail.
+ """ + utils.create_from_yaml( + k8s_client=api_client or self.api_client, + yaml_objects=yaml_objects, + yaml_file=yaml_file, + verbose=verbose, + namespace=namespace or self.get_namespace(), + ) + + def check_kueue_deployment_running( + self, name: str, namespace: str, timeout: float = 300.0, polling_period_seconds: float = 2.0 + ) -> None: + _timeout = timeout + while _timeout > 0: + try: + deployment = self.get_deployment_status(name=name, namespace=namespace) + except Exception as e: + self.log.exception("Exception occurred while checking for Deployment status.") + raise e + + deployment_status = V1Deployment.to_dict(deployment)["status"] + replicas = deployment_status["replicas"] + ready_replicas = deployment_status["ready_replicas"] + unavailable_replicas = deployment_status["unavailable_replicas"] + if ( + replicas is not None + and ready_replicas is not None + and unavailable_replicas is None + and replicas == ready_replicas + ): + return + else: + self.log.info("Waiting until Deployment will be ready...") + sleep(polling_period_seconds) + + _timeout -= polling_period_seconds + + raise AirflowException("Deployment timed out") + + @staticmethod + def get_yaml_content_from_file(kueue_yaml_url) -> list[dict]: + """Download content of YAML file and separate it into several dictionaries.""" + response = requests.get(kueue_yaml_url, allow_redirects=True) + if response.status_code != 200: + raise AirflowException("Was not able to read the yaml file from given URL") + + return list(yaml.safe_load_all(response.text)) + def _get_bool(val) -> bool | None: """Convert val to bool if can be done with certainty; if we cannot infer intention we return None.""" diff --git a/providers/src/airflow/providers/cncf/kubernetes/operators/kueue.py b/providers/src/airflow/providers/cncf/kubernetes/operators/kueue.py new file mode 100644 index 0000000000000..01bd29f260cc9 --- /dev/null +++ b/providers/src/airflow/providers/cncf/kubernetes/operators/kueue.py @@ -0,0 +1,105 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Manage a Kubernetes Kueue.""" + +from __future__ import annotations + +import json +import warnings +from collections.abc import Sequence +from functools import cached_property + +from kubernetes.utils import FailToCreateError + +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from airflow.providers.cncf.kubernetes.hooks.kubernetes import KubernetesHook +from airflow.providers.cncf.kubernetes.operators.job import KubernetesJobOperator + + +class KubernetesInstallKueueOperator(BaseOperator): + """ + Installs a Kubernetes Kueue. + + :param kueue_version: The Kubernetes Kueue version to install. + :param kubernetes_conn_id: The :ref:`kubernetes connection id ` + for the Kubernetes cluster. 
+ """ + + template_fields: Sequence[str] = ( + "kueue_version", + "kubernetes_conn_id", + ) + + def __init__(self, kueue_version: str, kubernetes_conn_id: str = "kubernetes_default", *args, **kwargs): + super().__init__(*args, **kwargs) + self.kubernetes_conn_id = kubernetes_conn_id + self.kueue_version = kueue_version + self._kueue_yaml_url = ( + f"https://github.com/kubernetes-sigs/kueue/releases/download/{self.kueue_version}/manifests.yaml" + ) + + @cached_property + def hook(self) -> KubernetesHook: + return KubernetesHook(conn_id=self.kubernetes_conn_id) + + def execute(self, context): + yaml_objects = self.hook.get_yaml_content_from_file(kueue_yaml_url=self._kueue_yaml_url) + try: + self.hook.apply_from_yaml_file(yaml_objects=yaml_objects) + except FailToCreateError as ex: + error_bodies = [json.loads(e.body) for e in ex.api_exceptions] + if next((e for e in error_bodies if e.get("reason") == "AlreadyExists"), None): + self.log.info("Kueue is already enabled for the cluster") + + if errors := [e for e in error_bodies if e.get("reason") != "AlreadyExists"]: + error_message = "\n".join(e.get("body") for e in errors) + raise AirflowException(error_message) + return + + self.hook.check_kueue_deployment_running(name="kueue-controller-manager", namespace="kueue-system") + self.log.info("Kueue installed successfully!") + + +class KubernetesStartKueueJobOperator(KubernetesJobOperator): + """ + Executes a Kubernetes Job in Kueue. + + :param queue_name: The name of the Queue in the cluster + """ + + template_fields = tuple({"queue_name"} | set(KubernetesJobOperator.template_fields)) + + def __init__(self, queue_name: str, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.queue_name = queue_name + + if self.suspend is False: + raise AirflowException( + "The `suspend` parameter can't be False. If you want to use Kueue for running Job" + " in a Kubernetes cluster, set the `suspend` parameter to True.", + ) + elif self.suspend is None: + warnings.warn( + f"You have not set parameter `suspend` in class {self.__class__.__name__}. " + "For running a Job in Kueue the `suspend` parameter should set to True.", + UserWarning, + stacklevel=2, + ) + self.suspend = True + self.labels.update({"kueue.x-k8s.io/queue-name": self.queue_name}) + self.annotations.update({"kueue.x-k8s.io/queue-name": self.queue_name}) diff --git a/providers/src/airflow/providers/cncf/kubernetes/operators/pod.py b/providers/src/airflow/providers/cncf/kubernetes/operators/pod.py index cbafc72f3455c..c7e0084fc1d83 100644 --- a/providers/src/airflow/providers/cncf/kubernetes/operators/pod.py +++ b/providers/src/airflow/providers/cncf/kubernetes/operators/pod.py @@ -396,7 +396,9 @@ def __init__( self.remote_pod: k8s.V1Pod | None = None self.log_pod_spec_on_failure = log_pod_spec_on_failure self.on_finish_action = OnFinishAction(on_finish_action) - self.is_delete_operator_pod = self.on_finish_action == OnFinishAction.DELETE_POD + # The `is_delete_operator_pod` parameter should have been removed in provider version 10.0.0. 
diff --git a/providers/src/airflow/providers/cncf/kubernetes/operators/pod.py b/providers/src/airflow/providers/cncf/kubernetes/operators/pod.py
index cbafc72f3455c..c7e0084fc1d83 100644
--- a/providers/src/airflow/providers/cncf/kubernetes/operators/pod.py
+++ b/providers/src/airflow/providers/cncf/kubernetes/operators/pod.py
@@ -396,7 +396,9 @@ def __init__(
         self.remote_pod: k8s.V1Pod | None = None
         self.log_pod_spec_on_failure = log_pod_spec_on_failure
         self.on_finish_action = OnFinishAction(on_finish_action)
-        self.is_delete_operator_pod = self.on_finish_action == OnFinishAction.DELETE_POD
+        # The `is_delete_operator_pod` parameter should have been removed in provider version 10.0.0.
+        # TODO: remove it here and from the operator's parameter list when the next major version is bumped.
+        self._is_delete_operator_pod = self.on_finish_action == OnFinishAction.DELETE_POD
         self.termination_message_policy = termination_message_policy
         self.active_deadline_seconds = active_deadline_seconds
         self.logging_interval = logging_interval
diff --git a/providers/src/airflow/providers/google/cloud/hooks/kubernetes_engine.py b/providers/src/airflow/providers/google/cloud/hooks/kubernetes_engine.py
index 23eab9dd3cd44..3b791c5359d5c 100644
--- a/providers/src/airflow/providers/google/cloud/hooks/kubernetes_engine.py
+++ b/providers/src/airflow/providers/google/cloud/hooks/kubernetes_engine.py
@@ -23,7 +23,7 @@ import json
 import time
 from collections.abc import Sequence
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
 
 from google.api_core.exceptions import NotFound
 from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
@@ -33,8 +33,7 @@ from google.cloud import exceptions  # type: ignore[attr-defined]
 from google.cloud.container_v1 import ClusterManagerAsyncClient, ClusterManagerClient
 from google.cloud.container_v1.types import Cluster, Operation
-from kubernetes import client, utils
-from kubernetes.client.models import V1Deployment
+from kubernetes import client
 from kubernetes_asyncio import client as async_client
 from kubernetes_asyncio.config.kube_config import FileOrData
@@ -434,38 +433,9 @@ def get_conn(self) -> client.ApiClient:
             enable_tcp_keepalive=self.enable_tcp_keepalive,
         ).get_conn()
 
-    def check_kueue_deployment_running(self, name, namespace):
-        timeout = 300
-        polling_period_seconds = 2
-
-        while timeout is None or timeout > 0:
-            try:
-                deployment = self.get_deployment_status(name=name, namespace=namespace)
-                deployment_status = V1Deployment.to_dict(deployment)["status"]
-                replicas = deployment_status["replicas"]
-                ready_replicas = deployment_status["ready_replicas"]
-                unavailable_replicas = deployment_status["unavailable_replicas"]
-                if (
-                    replicas is not None
-                    and ready_replicas is not None
-                    and unavailable_replicas is None
-                    and replicas == ready_replicas
-                ):
-                    return
-                else:
-                    self.log.info("Waiting until Deployment will be ready...")
-                    time.sleep(polling_period_seconds)
-            except Exception as e:
-                self.log.exception("Exception occurred while checking for Deployment status.")
-                raise e
-
-            if timeout is not None:
-                timeout -= polling_period_seconds
-
-        raise AirflowException("Deployment timed out")
-
     def apply_from_yaml_file(
         self,
+        api_client: Any = None,
         yaml_file: str | None = None,
         yaml_objects: list[dict] | None = None,
         verbose: bool = False,
@@ -474,18 +444,17 @@ def apply_from_yaml_file(
         """
         Perform an action from a yaml file.
 
+        :param api_client: Optional Kubernetes API client; defaults to the hook's client.
         :param yaml_file: Contains the path to yaml file.
         :param yaml_objects: List of YAML objects; used instead of reading the yaml_file.
         :param verbose: If True, print confirmation from create action. Default is False.
         :param namespace: Contains the namespace to create all resources inside. The namespace must
             preexist otherwise the resource creation will fail.
""" - k8s_client = self.get_conn() - - utils.create_from_yaml( - k8s_client=k8s_client, - yaml_objects=yaml_objects, + super().apply_from_yaml_file( + api_client=api_client or self.get_conn(), yaml_file=yaml_file, + yaml_objects=yaml_objects, verbose=verbose, namespace=namespace, ) diff --git a/providers/src/airflow/providers/google/cloud/operators/kubernetes_engine.py b/providers/src/airflow/providers/google/cloud/operators/kubernetes_engine.py index 82c3dd0ffefd6..759a0190ab210 100644 --- a/providers/src/airflow/providers/google/cloud/operators/kubernetes_engine.py +++ b/providers/src/airflow/providers/google/cloud/operators/kubernetes_engine.py @@ -24,17 +24,17 @@ from functools import cached_property from typing import TYPE_CHECKING, Any -import requests -import yaml from google.api_core.exceptions import AlreadyExists -from google.cloud.container_v1.types import Cluster from kubernetes.client import V1JobList, models as k8s -from kubernetes.utils.create_from_yaml import FailToCreateError from packaging.version import parse as parse_version from airflow.configuration import conf from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.providers.cncf.kubernetes.operators.job import KubernetesJobOperator +from airflow.providers.cncf.kubernetes.operators.kueue import ( + KubernetesInstallKueueOperator, + KubernetesStartKueueJobOperator, +) from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator from airflow.providers.cncf.kubernetes.operators.resource import ( KubernetesCreateResourceOperator, @@ -57,6 +57,7 @@ GKEOperationTrigger, GKEStartPodTrigger, ) +from airflow.providers.google.common.deprecated import deprecated from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID from airflow.providers_manager import ProvidersManager from airflow.utils.timezone import utcnow @@ -72,7 +73,8 @@ ) if TYPE_CHECKING: - from kubernetes.client.models import V1Job, V1Pod + from google.cloud.container_v1.types import Cluster + from kubernetes.client.models import V1Job from pendulum import DateTime from airflow.utils.context import Context @@ -92,17 +94,17 @@ class GKEClusterAuthDetails: def __init__( self, - cluster_name, - project_id, - use_internal_ip, - cluster_hook, + cluster_name: str, + project_id: str, + use_internal_ip: bool, + cluster_hook: GKEHook, ): self.cluster_name = cluster_name self.project_id = project_id self.use_internal_ip = use_internal_ip self.cluster_hook = cluster_hook - self._cluster_url = None - self._ssl_ca_cert = None + self._cluster_url: str + self._ssl_ca_cert: str def fetch_cluster_info(self) -> tuple[str, str]: """Fetch cluster info for connecting to it.""" @@ -119,11 +121,100 @@ def fetch_cluster_info(self) -> tuple[str, str]: return self._cluster_url, self._ssl_ca_cert -class GKEDeleteClusterOperator(GoogleCloudBaseOperator): +class GKEBaseOperator(GoogleCloudBaseOperator): + """ + Base class for all GKE operators. + + :param location: The name of the Google Kubernetes Engine zone or region in which the + cluster resides, e.g. 'us-central1-a' + :param cluster_name: The name of the Google Kubernetes Engine cluster. + :param use_internal_ip: Use the internal IP address as the endpoint. + :param project_id: The Google Developers Console project id + :param gcp_conn_id: The Google cloud connection id to use. This allows for + users to specify a service account. 
+ :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + """ + + enable_tcp_keepalive = False + + template_fields: Sequence[str] = ( + "location", + "cluster_name", + "use_internal_ip", + "project_id", + "gcp_conn_id", + "impersonation_chain", + ) + + def __init__( + self, + location: str, + cluster_name: str, + use_internal_ip: bool = False, + project_id: str = PROVIDE_PROJECT_ID, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.project_id = project_id + self.location = location + self.cluster_name = cluster_name + self.gcp_conn_id = gcp_conn_id + self.use_internal_ip = use_internal_ip + self.impersonation_chain = impersonation_chain + + @cached_property + def cluster_hook(self) -> GKEHook: + return GKEHook( + gcp_conn_id=self.gcp_conn_id, + location=self.location, + impersonation_chain=self.impersonation_chain, + ) + + @cached_property + def hook(self) -> GKEKubernetesHook: + return GKEKubernetesHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + cluster_url=self.cluster_url, + ssl_ca_cert=self.ssl_ca_cert, + enable_tcp_keepalive=self.enable_tcp_keepalive, + ) + + @cached_property + def cluster_info(self) -> tuple[str, str]: + """Fetch cluster info for connecting to it.""" + auth_details = GKEClusterAuthDetails( + cluster_name=self.cluster_name, + project_id=self.project_id, + use_internal_ip=self.use_internal_ip, + cluster_hook=self.cluster_hook, + ) + return auth_details.fetch_cluster_info() + + @property + def cluster_url(self) -> str: + return self.cluster_info[0] + + @property + def ssl_ca_cert(self) -> str: + return self.cluster_info[1] + + +class GKEDeleteClusterOperator(GKEBaseOperator): """ Deletes the cluster, including the Kubernetes endpoint and all worker nodes. - To delete a certain cluster, you must specify the ``project_id``, the ``name`` + To delete a certain cluster, you must specify the ``project_id``, the ``cluster_name`` of the cluster, the ``location`` that the cluster is in, and the ``task_id``. **Operator Creation**: :: @@ -132,7 +223,7 @@ class GKEDeleteClusterOperator(GoogleCloudBaseOperator): task_id='cluster_delete', project_id='my-project', location='cluster-location' - name='cluster-name') + cluster_name='cluster-name') .. seealso:: For more detail about deleting clusters have a look at the reference: @@ -141,72 +232,63 @@ class GKEDeleteClusterOperator(GoogleCloudBaseOperator): .. seealso:: For more information on how to use this operator, take a look at the guide: :ref:`howto/operator:GKEDeleteClusterOperator` - - :param project_id: The Google Developers Console [project ID or project number] - :param name: The name of the resource to delete, in this case cluster name - :param location: The name of the Google Kubernetes Engine zone or region in which the cluster - resides. - :param gcp_conn_id: The connection ID to use connecting to Google Cloud. 
+ :param name: (Deprecated) The name of the resource to delete, in this case cluster name :param api_version: The api version to use - :param impersonation_chain: Optional service account to impersonate using short-term - credentials, or chained list of accounts required to get the access_token - of the last account in the list, which will be impersonated in the request. - If set as a string, the account must grant the originating account - the Service Account Token Creator IAM role. - If set as a sequence, the identities from the list must grant - Service Account Token Creator IAM role to the directly preceding identity, with first - account from the list granting this role to the originating account (templated). :param deferrable: Run operator in the deferrable mode. :param poll_interval: Interval size which defines how often operation status is checked. """ - template_fields: Sequence[str] = ( - "project_id", - "gcp_conn_id", - "name", - "location", - "api_version", - "impersonation_chain", + template_fields: Sequence[str] = tuple( + {"api_version", "deferrable", "poll_interval"} | set(GKEBaseOperator.template_fields) ) def __init__( self, - *, - name: str, - location: str, - project_id: str = PROVIDE_PROJECT_ID, - gcp_conn_id: str = "google_cloud_default", + name: str | None = None, api_version: str = "v2", - impersonation_chain: str | Sequence[str] | None = None, deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), poll_interval: int = 10, + *args, **kwargs, ) -> None: - super().__init__(**kwargs) + if "cluster_name" not in kwargs: + kwargs["cluster_name"] = name + + super().__init__(*args, **kwargs) - self.project_id = project_id - self.gcp_conn_id = gcp_conn_id - self.location = location self.api_version = api_version - self.name = name - self.impersonation_chain = impersonation_chain + self._name = name self.deferrable = deferrable self.poll_interval = poll_interval self._check_input() - self._hook: GKEHook | None = None + @property + @deprecated( + planned_removal_date="May 01, 2025", + use_instead="cluster_name_get", + category=AirflowProviderDeprecationWarning, + ) + def name(self) -> str | None: + return self._name + + @name.setter + @deprecated( + planned_removal_date="May 01, 2025", + use_instead="cluster_name_set", + category=AirflowProviderDeprecationWarning, + ) + def name(self, name: str) -> None: + self._name = name def _check_input(self) -> None: - if not all([self.project_id, self.name, self.location]): - self.log.error("One of (project_id, name, location) is missing or incorrect") + if not all([self.project_id, self.cluster_name, self.location]): + self.log.error("One of (project_id, cluster_name, location) is missing or incorrect") raise AirflowException("Operator has incorrect or missing input.") def execute(self, context: Context) -> str | None: - hook = self._get_hook() - wait_to_complete = not self.deferrable - operation = hook.delete_cluster( - name=self.name, + operation = self.cluster_hook.delete_cluster( + name=self.cluster_name, project_id=self.project_id, wait_to_complete=wait_to_complete, ) @@ -236,23 +318,13 @@ def execute_complete(self, context: Context, event: dict) -> str: raise AirflowException(message) self.log.info(message) - operation = self._get_hook().get_operation( + operation = self.cluster_hook.get_operation( operation_name=event["operation_name"], ) return operation.self_link - def _get_hook(self) -> GKEHook: - if self._hook is None: - self._hook = GKEHook( - gcp_conn_id=self.gcp_conn_id, - 
location=self.location, - impersonation_chain=self.impersonation_chain, - ) - - return self._hook - -class GKECreateClusterOperator(GoogleCloudBaseOperator): +class GKECreateClusterOperator(GKEBaseOperator): """ Create a Google Kubernetes Engine Cluster of specified dimensions and wait until the cluster is created. @@ -284,61 +356,35 @@ class GKECreateClusterOperator(GoogleCloudBaseOperator): For more information on how to use this operator, take a look at the guide: :ref:`howto/operator:GKECreateClusterOperator` - :param project_id: The Google Developers Console [project ID or project number] - :param location: The name of the Google Kubernetes Engine zone or region in which the cluster - resides. :param body: The Cluster definition to create, can be protobuf or python dict, if dict it must match protobuf message Cluster - :param gcp_conn_id: The connection ID to use connecting to Google Cloud. :param api_version: The api version to use - :param impersonation_chain: Optional service account to impersonate using short-term - credentials, or chained list of accounts required to get the access_token - of the last account in the list, which will be impersonated in the request. - If set as a string, the account must grant the originating account - the Service Account Token Creator IAM role. - If set as a sequence, the identities from the list must grant - Service Account Token Creator IAM role to the directly preceding identity, with first - account from the list granting this role to the originating account (templated). :param deferrable: Run operator in the deferrable mode. :param poll_interval: Interval size which defines how often operation status is checked. """ - template_fields: Sequence[str] = ( - "project_id", - "gcp_conn_id", - "location", - "api_version", - "body", - "impersonation_chain", + template_fields: Sequence[str] = tuple( + {"body", "api_version", "deferrable", "poll_interval"} | set(GKEBaseOperator.template_fields) ) operator_extra_links = (KubernetesEngineClusterLink(),) def __init__( self, - *, - location: str, body: dict | Cluster, - project_id: str = PROVIDE_PROJECT_ID, - gcp_conn_id: str = "google_cloud_default", api_version: str = "v2", - impersonation_chain: str | Sequence[str] | None = None, - poll_interval: int = 10, deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + poll_interval: int = 10, + *args, **kwargs, ) -> None: - super().__init__(**kwargs) - - self.project_id = project_id - self.gcp_conn_id = gcp_conn_id - self.location = location self.api_version = api_version self.body = body - self.impersonation_chain = impersonation_chain self.poll_interval = poll_interval self.deferrable = deferrable self._validate_input() - self._hook: GKEHook | None = None + kwargs["cluster_name"] = body["name"] if isinstance(body, dict) else body.name + super().__init__(*args, **kwargs) def _validate_input(self) -> None: """Primary validation of the input body.""" @@ -404,36 +450,32 @@ def _alert_deprecated_body_fields(self) -> None: ) def execute(self, context: Context) -> str: - hook = self._get_hook() + KubernetesEngineClusterLink.persist(context=context, task_instance=self, cluster=self.body) + try: - wait_to_complete = not self.deferrable - operation = hook.create_cluster( + operation = self.cluster_hook.create_cluster( cluster=self.body, project_id=self.project_id, - wait_to_complete=wait_to_complete, + wait_to_complete=not self.deferrable, ) - - KubernetesEngineClusterLink.persist(context=context, task_instance=self, cluster=self.body) 
- - if self.deferrable: - self.defer( - trigger=GKEOperationTrigger( - operation_name=operation.name, - project_id=self.project_id, - location=self.location, - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - poll_interval=self.poll_interval, - ), - method_name="execute_complete", - ) - - return operation.target_link - except AlreadyExists as error: self.log.info("Assuming Success: %s", error.message) - name = self.body.name if isinstance(self.body, Cluster) else self.body["name"] - return hook.get_cluster(name=name, project_id=self.project_id).self_link + return self.cluster_hook.get_cluster(name=self.cluster_name, project_id=self.project_id).self_link + + if self.deferrable: + self.defer( + trigger=GKEOperationTrigger( + operation_name=operation.name, + project_id=self.project_id, + location=self.location, + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + poll_interval=self.poll_interval, + ), + method_name="execute_complete", + ) + + return operation.target_link def execute_complete(self, context: Context, event: dict) -> str: status = event["status"] @@ -444,23 +486,13 @@ def execute_complete(self, context: Context, event: dict) -> str: raise AirflowException(message) self.log.info(message) - operation = self._get_hook().get_operation( + operation = self.cluster_hook.get_operation( operation_name=event["operation_name"], ) return operation.target_link - def _get_hook(self) -> GKEHook: - if self._hook is None: - self._hook = GKEHook( - gcp_conn_id=self.gcp_conn_id, - location=self.location, - impersonation_chain=self.impersonation_chain, - ) - - return self._hook - -class GKEStartKueueInsideClusterOperator(GoogleCloudBaseOperator): +class GKEStartKueueInsideClusterOperator(GKEBaseOperator, KubernetesInstallKueueOperator): """ Installs Kueue of specific version inside Cluster. @@ -471,137 +503,27 @@ class GKEStartKueueInsideClusterOperator(GoogleCloudBaseOperator): .. seealso:: For more details about Kueue have a look at the reference: https://kueue.sigs.k8s.io/docs/overview/ - - :param project_id: The Google Developers Console [project ID or project number]. - :param location: The name of the Google Kubernetes Engine zone or region in which the cluster resides. - :param cluster_name: The Cluster name in which to install Kueue. - :param kueue_version: Version of Kueue to install. - :param gcp_conn_id: The connection ID to use connecting to Google Cloud. - :param impersonation_chain: Optional service account to impersonate using short-term - credentials, or chained list of accounts required to get the access_token - of the last account in the list, which will be impersonated in the request. - If set as a string, the account must grant the originating account - the Service Account Token Creator IAM role. - If set as a sequence, the identities from the list must grant - Service Account Token Creator IAM role to the directly preceding identity, with first - account from the list granting this role to the originating account (templated). 
""" - template_fields: Sequence[str] = ( - "project_id", - "location", - "kueue_version", - "cluster_name", - "gcp_conn_id", - "impersonation_chain", + enable_tcp_keepalive = True + template_fields = tuple( + set(GKEBaseOperator.template_fields) | set(KubernetesInstallKueueOperator.template_fields) ) operator_extra_links = (KubernetesEngineClusterLink(),) - def __init__( - self, - *, - location: str, - cluster_name: str, - kueue_version: str, - use_internal_ip: bool = False, - project_id: str = PROVIDE_PROJECT_ID, - gcp_conn_id: str = "google_cloud_default", - impersonation_chain: str | Sequence[str] | None = None, - **kwargs, - ) -> None: - super().__init__(**kwargs) - self.project_id = project_id - self.location = location - self.cluster_name = cluster_name - self.kueue_version = kueue_version - self.gcp_conn_id = gcp_conn_id - self.impersonation_chain = impersonation_chain - self.use_internal_ip = use_internal_ip - self._kueue_yaml_url = ( - f"https://github.com/kubernetes-sigs/kueue/releases/download/{self.kueue_version}/manifests.yaml" - ) - - @cached_property - def cluster_hook(self) -> GKEHook: - return GKEHook( - gcp_conn_id=self.gcp_conn_id, - location=self.location, - impersonation_chain=self.impersonation_chain, - ) - - @cached_property - def deployment_hook(self) -> GKEKubernetesHook: - if self._cluster_url is None or self._ssl_ca_cert is None: - raise AttributeError( - "Cluster url and ssl_ca_cert should be defined before using self.deployment_hook method. " - "Try to use self.get_kube_creds method", - ) - return GKEKubernetesHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - cluster_url=self._cluster_url, - ssl_ca_cert=self._ssl_ca_cert, - ) - - @cached_property - def pod_hook(self) -> GKEKubernetesHook: - if self._cluster_url is None or self._ssl_ca_cert is None: - raise AttributeError( - "Cluster url and ssl_ca_cert should be defined before using self.pod_hook method. 
" - "Try to use self.get_kube_creds method", - ) - - return GKEKubernetesHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - cluster_url=self._cluster_url, - ssl_ca_cert=self._ssl_ca_cert, - enable_tcp_keepalive=True, - ) - - @staticmethod - def _get_yaml_content_from_file(kueue_yaml_url) -> list[dict]: - """Download content of YAML file and separate it into several dictionaries.""" - response = requests.get(kueue_yaml_url, allow_redirects=True) - if response.status_code != 200: - raise AirflowException("Was not able to read the yaml file from given URL") - - return list(yaml.safe_load_all(response.text)) - def execute(self, context: Context): - self._cluster_url, self._ssl_ca_cert = GKEClusterAuthDetails( - cluster_name=self.cluster_name, - project_id=self.project_id, - use_internal_ip=self.use_internal_ip, - cluster_hook=self.cluster_hook, - ).fetch_cluster_info() - - cluster = self.cluster_hook.get_cluster( - name=self.cluster_name, - project_id=self.project_id, - ) + cluster = self.cluster_hook.get_cluster(name=self.cluster_name, project_id=self.project_id) KubernetesEngineClusterLink.persist(context=context, task_instance=self, cluster=cluster) - yaml_objects = self._get_yaml_content_from_file(kueue_yaml_url=self._kueue_yaml_url) - if self.cluster_hook.check_cluster_autoscaling_ability(cluster=cluster): - try: - self.pod_hook.apply_from_yaml_file(yaml_objects=yaml_objects) - - self.deployment_hook.check_kueue_deployment_running( - name="kueue-controller-manager", namespace="kueue-system" - ) - - self.log.info("Kueue installed successfully!") - except FailToCreateError: - self.log.info("Kueue is already enabled for the cluster") + super().execute(context) else: self.log.info( "Cluster doesn't have ability to autoscale, will not install Kueue inside. Aborting" ) -class GKEStartPodOperator(KubernetesPodOperator): +class GKEStartPodOperator(GKEBaseOperator, KubernetesPodOperator): """ Executes a task in a Kubernetes pod in the specified Google Kubernetes Engine cluster. @@ -620,151 +542,98 @@ class GKEStartPodOperator(KubernetesPodOperator): For more information on how to use this operator, take a look at the guide: :ref:`howto/operator:GKEStartPodOperator` - :param location: The name of the Google Kubernetes Engine zone or region in which the - cluster resides, e.g. 'us-central1-a' - :param cluster_name: The name of the Google Kubernetes Engine cluster the pod - should be spawned in - :param use_internal_ip: Use the internal IP address as the endpoint. - :param project_id: The Google Developers Console project id - :param gcp_conn_id: The Google cloud connection id to use. This allows for - users to specify a service account. - :param impersonation_chain: Optional service account to impersonate using short-term - credentials, or list of accounts required to get the access_token - of the last account in the list, which will be impersonated in the request. - If set as a string, the account must grant the originating account - the Service Account Token Creator IAM role. - If set as a sequence, the identities from the list must grant - Service Account Token Creator IAM role to the directly preceding identity, with first - account from the list granting this role to the originating account (templated). - :param regional: The location param is region name. - :param deferrable: Run operator in the deferrable mode. + :param regional: (Deprecated) The location param is region name. 
:param on_finish_action: What to do when the pod reaches its final state, or the execution is interrupted. If "delete_pod", the pod will be deleted regardless its state; if "delete_succeeded_pod", only succeeded pod will be deleted. You can set to "keep_pod" to keep the pod. Current default is `keep_pod`, but this will be changed in the next major release of this provider. - :param is_delete_operator_pod: What to do when the pod reaches its final + :param is_delete_operator_pod: (Deprecated) What to do when the pod reaches its final state, or the execution is interrupted. If True, delete the pod; if False, leave the pod. Current default is False, but this will be changed in the next major release of this provider. Deprecated - use `on_finish_action` instead. + :param deferrable: Run operator in the deferrable mode. """ template_fields: Sequence[str] = tuple( - {"project_id", "location", "cluster_name"} | set(KubernetesPodOperator.template_fields) + {"on_finish_action", "deferrable"} + | (set(KubernetesPodOperator.template_fields) - {"is_delete_operator_pod", "regional"}) + | set(GKEBaseOperator.template_fields) ) operator_extra_links = (KubernetesEnginePodLink(),) def __init__( self, - *, - location: str, - cluster_name: str, - use_internal_ip: bool = False, - project_id: str = PROVIDE_PROJECT_ID, - gcp_conn_id: str = "google_cloud_default", - impersonation_chain: str | Sequence[str] | None = None, regional: bool | None = None, on_finish_action: str | None = None, is_delete_operator_pod: bool | None = None, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + *args, **kwargs, ) -> None: if is_delete_operator_pod is not None: - warnings.warn( - "`is_delete_operator_pod` parameter is deprecated, please use `on_finish_action`", - AirflowProviderDeprecationWarning, - stacklevel=2, - ) kwargs["on_finish_action"] = ( OnFinishAction.DELETE_POD if is_delete_operator_pod else OnFinishAction.KEEP_POD ) + elif on_finish_action is not None: + kwargs["on_finish_action"] = OnFinishAction(on_finish_action) else: - if on_finish_action is not None: - kwargs["on_finish_action"] = OnFinishAction(on_finish_action) - else: - warnings.warn( - f"You have not set parameter `on_finish_action` in class {self.__class__.__name__}. " - "Currently the default for this parameter is `keep_pod` but in a future release" - " the default will be changed to `delete_pod`. To ensure pods are not deleted in" - " the future you will need to set `on_finish_action=keep_pod` explicitly.", - AirflowProviderDeprecationWarning, - stacklevel=2, - ) - kwargs["on_finish_action"] = OnFinishAction.KEEP_POD - - if regional is not None: warnings.warn( - f"You have set parameter regional in class {self.__class__.__name__}. " - "In current implementation of the operator the parameter is not used and will " - "be deleted in future.", + f"You have not set parameter `on_finish_action` in class {self.__class__.__name__}. " + "Currently the default for this parameter is `keep_pod` but in a future release" + " the default will be changed to `delete_pod`. 
To ensure pods are not deleted in" + " the future you will need to set `on_finish_action=keep_pod` explicitly.", AirflowProviderDeprecationWarning, stacklevel=2, ) + kwargs["on_finish_action"] = OnFinishAction.KEEP_POD - super().__init__(**kwargs) - self.project_id = project_id - self.location = location - self.cluster_name = cluster_name - self.gcp_conn_id = gcp_conn_id - self.impersonation_chain = impersonation_chain - self.use_internal_ip = use_internal_ip - - self.pod: V1Pod | None = None - self._ssl_ca_cert: str | None = None - self._cluster_url: str | None = None + super().__init__(*args, **kwargs) + self._regional = regional + if is_delete_operator_pod is not None: + self.is_delete_operator_pod = is_delete_operator_pod + self.deferrable = deferrable - if self.gcp_conn_id is None: - raise AirflowException( - "The gcp_conn_id parameter has become required. If you want to use Application Default " - "Credentials (ADC) strategy for authorization, create an empty connection " - "called `google_cloud_default`.", - ) # There is no need to manage the kube_config file, as it will be generated automatically. # All Kubernetes parameters (except config_file) are also valid for the GKEStartPodOperator. if self.config_file: raise AirflowException("config_file is not an allowed parameter for the GKEStartPodOperator.") - @cached_property - def cluster_hook(self) -> GKEHook: - return GKEHook( - gcp_conn_id=self.gcp_conn_id, - location=self.location, - impersonation_chain=self.impersonation_chain, - ) - - @cached_property - def hook(self) -> GKEKubernetesHook: - if self._cluster_url is None or self._ssl_ca_cert is None: - raise AttributeError( - "Cluster url and ssl_ca_cert should be defined before using self.hook method. " - "Try to use self.get_kube_creds method", - ) - - return GKEKubernetesHook( - gcp_conn_id=self.gcp_conn_id, - cluster_url=self._cluster_url, - ssl_ca_cert=self._ssl_ca_cert, - impersonation_chain=self.impersonation_chain, - enable_tcp_keepalive=True, - ) - - def execute(self, context: Context): - """Execute process of creating pod and executing provided command inside it.""" - self.fetch_cluster_info() - return super().execute(context) - - def fetch_cluster_info(self) -> tuple[str, str | None]: - """Fetch cluster info for connecting to it.""" - cluster = self.cluster_hook.get_cluster( - name=self.cluster_name, - project_id=self.project_id, - ) - - if not self.use_internal_ip: - self._cluster_url = f"https://{cluster.endpoint}" - else: - self._cluster_url = f"https://{cluster.private_cluster_config.private_endpoint}" - self._ssl_ca_cert = cluster.master_auth.cluster_ca_certificate - return self._cluster_url, self._ssl_ca_cert + @property + @deprecated( + planned_removal_date="May 01, 2025", + use_instead="on_finish_action", + category=AirflowProviderDeprecationWarning, + ) + def is_delete_operator_pod(self) -> bool | None: + return self._is_delete_operator_pod + + @is_delete_operator_pod.setter + @deprecated( + planned_removal_date="May 01, 2025", + use_instead="on_finish_action", + category=AirflowProviderDeprecationWarning, + ) + def is_delete_operator_pod(self, is_delete_operator_pod) -> None: + self._is_delete_operator_pod = is_delete_operator_pod + + @property + @deprecated( + planned_removal_date="May 01, 2025", + reason="The parameter is not in actual use.", + category=AirflowProviderDeprecationWarning, + ) + def regional(self) -> bool | None: + return self._regional + + @regional.setter + @deprecated( + planned_removal_date="May 01, 2025", + reason="The parameter is not in 
actual use.", + category=AirflowProviderDeprecationWarning, + ) + def regional(self, regional) -> None: + self._regional = regional def invoke_defer_method(self, last_log_time: DateTime | None = None): """Redefine triggers which are being used in child classes.""" @@ -774,8 +643,8 @@ def invoke_defer_method(self, last_log_time: DateTime | None = None): pod_name=self.pod.metadata.name, # type: ignore[union-attr] pod_namespace=self.pod.metadata.namespace, # type: ignore[union-attr] trigger_start_time=trigger_start_time, - cluster_url=self._cluster_url, # type: ignore[arg-type] - ssl_ca_cert=self._ssl_ca_cert, # type: ignore[arg-type] + cluster_url=self.cluster_url, + ssl_ca_cert=self.ssl_ca_cert, get_logs=self.get_logs, startup_timeout=self.startup_timeout_seconds, cluster_context=self.cluster_context, @@ -788,19 +657,11 @@ def invoke_defer_method(self, last_log_time: DateTime | None = None): logging_interval=self.logging_interval, last_log_time=last_log_time, ), - method_name="execute_complete", - kwargs={"cluster_url": self._cluster_url, "ssl_ca_cert": self._ssl_ca_cert}, + method_name="trigger_reentry", ) - def execute_complete(self, context: Context, event: dict, **kwargs): - # It is required for hook to be initialized - self._cluster_url = kwargs["cluster_url"] - self._ssl_ca_cert = kwargs["ssl_ca_cert"] - - return super().trigger_reentry(context, event) - -class GKEStartJobOperator(KubernetesJobOperator): +class GKEStartJobOperator(GKEBaseOperator, KubernetesJobOperator): """ Executes a Kubernetes job in the specified Google Kubernetes Engine cluster. @@ -819,92 +680,34 @@ class GKEStartJobOperator(KubernetesJobOperator): For more information on how to use this operator, take a look at the guide: :ref:`howto/operator:GKEStartJobOperator` - :param location: The name of the Google Kubernetes Engine zone or region in which the - cluster resides, e.g. 'us-central1-a' - :param cluster_name: The name of the Google Kubernetes Engine cluster - :param use_internal_ip: Use the internal IP address as the endpoint. - :param project_id: The Google Developers Console project id - :param gcp_conn_id: The Google cloud connection id to use. This allows for - users to specify a service account. - :param impersonation_chain: Optional service account to impersonate using short-term - credentials, or list of accounts required to get the access_token - of the last account in the list, which will be impersonated in the request. - If set as a string, the account must grant the originating account - the Service Account Token Creator IAM role. - If set as a sequence, the identities from the list must grant - Service Account Token Creator IAM role to the directly preceding identity, with first - account from the list granting this role to the originating account (templated). - :param location: The location param is region name. :param deferrable: Run operator in the deferrable mode. :param poll_interval: (Deferrable mode only) polling period in seconds to check for the status of job. 
""" template_fields: Sequence[str] = tuple( - {"project_id", "location", "cluster_name"} | set(KubernetesJobOperator.template_fields) + {"deferrable", "poll_interval"} + | set(GKEBaseOperator.template_fields) + | set(KubernetesJobOperator.template_fields) ) operator_extra_links = (KubernetesEngineJobLink(),) def __init__( self, - *, - location: str, - cluster_name: str, - use_internal_ip: bool = False, - project_id: str = PROVIDE_PROJECT_ID, - gcp_conn_id: str = "google_cloud_default", - impersonation_chain: str | Sequence[str] | None = None, deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), job_poll_interval: float = 10.0, + *args, **kwargs, ) -> None: - super().__init__(**kwargs) - self.project_id = project_id - self.location = location - self.cluster_name = cluster_name - self.gcp_conn_id = gcp_conn_id - self.impersonation_chain = impersonation_chain - self.use_internal_ip = use_internal_ip + super().__init__(*args, **kwargs) self.deferrable = deferrable self.job_poll_interval = job_poll_interval - self.job: V1Job | None = None - self._ssl_ca_cert: str | None = None - self._cluster_url: str | None = None - - if self.gcp_conn_id is None: - raise AirflowException( - "The gcp_conn_id parameter has become required. If you want to use Application Default " - "Credentials (ADC) strategy for authorization, create an empty connection " - "called `google_cloud_default`.", - ) # There is no need to manage the kube_config file, as it will be generated automatically. # All Kubernetes parameters (except config_file) are also valid for the GKEStartJobOperator. if self.config_file: raise AirflowException("config_file is not an allowed parameter for the GKEStartJobOperator.") - @cached_property - def cluster_hook(self) -> GKEHook: - return GKEHook( - gcp_conn_id=self.gcp_conn_id, - location=self.location, - impersonation_chain=self.impersonation_chain, - ) - - @cached_property - def hook(self) -> GKEKubernetesHook: - if self._cluster_url is None or self._ssl_ca_cert is None: - raise AttributeError( - "Cluster url and ssl_ca_cert should be defined before using self.hook method. " - "Try to use self.get_kube_creds method", - ) - - return GKEKubernetesHook( - gcp_conn_id=self.gcp_conn_id, - cluster_url=self._cluster_url, - ssl_ca_cert=self._ssl_ca_cert, - ) - def execute(self, context: Context): """Execute process of creating Job.""" if self.deferrable: @@ -918,21 +721,13 @@ def execute(self, context: Context): f"package {kubernetes_provider_name}=={kubernetes_provider_version} which doesn't " f"support this feature. Please upgrade it to version higher than {min_version}." 
) - - self._cluster_url, self._ssl_ca_cert = GKEClusterAuthDetails( - cluster_name=self.cluster_name, - project_id=self.project_id, - use_internal_ip=self.use_internal_ip, - cluster_hook=self.cluster_hook, - ).fetch_cluster_info() - return super().execute(context) def execute_deferrable(self): self.defer( trigger=GKEJobTrigger( - cluster_url=self._cluster_url, - ssl_ca_cert=self._ssl_ca_cert, + cluster_url=self.cluster_url, + ssl_ca_cert=self.ssl_ca_cert, job_name=self.job.metadata.name, # type: ignore[union-attr] job_namespace=self.job.metadata.namespace, # type: ignore[union-attr] pod_name=self.pod.metadata.name, # type: ignore[union-attr] @@ -945,18 +740,10 @@ def execute_deferrable(self): do_xcom_push=self.do_xcom_push, ), method_name="execute_complete", - kwargs={"cluster_url": self._cluster_url, "ssl_ca_cert": self._ssl_ca_cert}, ) - def execute_complete(self, context: Context, event: dict, **kwargs): - # It is required for hook to be initialized - self._cluster_url = kwargs["cluster_url"] - self._ssl_ca_cert = kwargs["ssl_ca_cert"] - - return super().execute_complete(context, event) - -class GKEDescribeJobOperator(GoogleCloudBaseOperator): +class GKEDescribeJobOperator(GKEBaseOperator): """ Retrieve information about Job by given name. @@ -965,84 +752,24 @@ class GKEDescribeJobOperator(GoogleCloudBaseOperator): :ref:`howto/operator:GKEDescribeJobOperator` :param job_name: The name of the Job to delete - :param project_id: The Google Developers Console project id. - :param location: The name of the Google Kubernetes Engine zone or region in which the cluster - resides. - :param cluster_name: The name of the Google Kubernetes Engine cluster. :param namespace: The name of the Google Kubernetes Engine namespace. - :param use_internal_ip: Use the internal IP address as the endpoint. - :param gcp_conn_id: The connection ID to use connecting to Google Cloud. - :param impersonation_chain: Optional service account to impersonate using short-term - credentials, or chained list of accounts required to get the access_token - of the last account in the list, which will be impersonated in the request. - If set as a string, the account must grant the originating account - the Service Account Token Creator IAM role. - If set as a sequence, the identities from the list must grant - Service Account Token Creator IAM role to the directly preceding identity, with first - account from the list granting this role to the originating account (templated). 
""" - template_fields: Sequence[str] = ( - "project_id", - "gcp_conn_id", - "job_name", - "namespace", - "cluster_name", - "location", - "impersonation_chain", - ) + template_fields: Sequence[str] = tuple({"job_name", "namespace"} | set(GKEBaseOperator.template_fields)) operator_extra_links = (KubernetesEngineJobLink(),) def __init__( self, - *, job_name: str, - location: str, namespace: str, - cluster_name: str, - project_id: str = PROVIDE_PROJECT_ID, - use_internal_ip: bool = False, - gcp_conn_id: str = "google_cloud_default", - impersonation_chain: str | Sequence[str] | None = None, + *args, **kwargs, ) -> None: - super().__init__(**kwargs) + super().__init__(*args, **kwargs) - self.project_id = project_id - self.gcp_conn_id = gcp_conn_id - self.location = location self.job_name = job_name self.namespace = namespace - self.cluster_name = cluster_name - self.use_internal_ip = use_internal_ip - self.impersonation_chain = impersonation_chain - self.job: V1Job | None = None - self._ssl_ca_cert: str - self._cluster_url: str - - @cached_property - def cluster_hook(self) -> GKEHook: - return GKEHook( - gcp_conn_id=self.gcp_conn_id, - location=self.location, - impersonation_chain=self.impersonation_chain, - ) - - @cached_property - def hook(self) -> GKEKubernetesHook: - self._cluster_url, self._ssl_ca_cert = GKEClusterAuthDetails( - cluster_name=self.cluster_name, - project_id=self.project_id, - use_internal_ip=self.use_internal_ip, - cluster_hook=self.cluster_hook, - ).fetch_cluster_info() - - return GKEKubernetesHook( - gcp_conn_id=self.gcp_conn_id, - cluster_url=self._cluster_url, - ssl_ca_cert=self._ssl_ca_cert, - ) def execute(self, context: Context) -> None: self.job = self.hook.get_job(job_name=self.job_name, namespace=self.namespace) @@ -1056,7 +783,7 @@ def execute(self, context: Context) -> None: return None -class GKEListJobsOperator(GoogleCloudBaseOperator): +class GKEListJobsOperator(GKEBaseOperator): """ Retrieve list of Jobs. @@ -1067,83 +794,24 @@ class GKEListJobsOperator(GoogleCloudBaseOperator): For more information on how to use this operator, take a look at the guide: :ref:`howto/operator:GKEListJobsOperator` - :param project_id: The Google Developers Console project id. - :param location: The name of the Google Kubernetes Engine zone or region in which the cluster - resides. - :param cluster_name: The name of the Google Kubernetes Engine cluster. :param namespace: The name of the Google Kubernetes Engine namespace. - :param use_internal_ip: Use the internal IP address as the endpoint. - :param gcp_conn_id: The connection ID to use connecting to Google Cloud. :param do_xcom_push: If set to True the result list of Jobs will be pushed to the task result. - :param impersonation_chain: Optional service account to impersonate using short-term - credentials, or chained list of accounts required to get the access_token - of the last account in the list, which will be impersonated in the request. - If set as a string, the account must grant the originating account - the Service Account Token Creator IAM role. - If set as a sequence, the identities from the list must grant - Service Account Token Creator IAM role to the directly preceding identity, with first - account from the list granting this role to the originating account (templated). 
""" - template_fields: Sequence[str] = ( - "project_id", - "gcp_conn_id", - "namespace", - "cluster_name", - "location", - "impersonation_chain", - ) + template_fields: Sequence[str] = tuple({"namespace"} | set(GKEBaseOperator.template_fields)) operator_extra_links = (KubernetesEngineWorkloadsLink(),) def __init__( self, - *, - location: str, - cluster_name: str, namespace: str | None = None, - project_id: str = PROVIDE_PROJECT_ID, - use_internal_ip: bool = False, do_xcom_push: bool = True, - gcp_conn_id: str = "google_cloud_default", - impersonation_chain: str | Sequence[str] | None = None, + *args, **kwargs, ) -> None: - super().__init__(**kwargs) + super().__init__(*args, **kwargs) - self.project_id = project_id - self.gcp_conn_id = gcp_conn_id - self.location = location self.namespace = namespace - self.cluster_name = cluster_name - self.use_internal_ip = use_internal_ip self.do_xcom_push = do_xcom_push - self.impersonation_chain = impersonation_chain - - self._ssl_ca_cert: str - self._cluster_url: str - - @cached_property - def cluster_hook(self) -> GKEHook: - return GKEHook( - gcp_conn_id=self.gcp_conn_id, - location=self.location, - impersonation_chain=self.impersonation_chain, - ) - - @cached_property - def hook(self) -> GKEKubernetesHook: - self._cluster_url, self._ssl_ca_cert = GKEClusterAuthDetails( - cluster_name=self.cluster_name, - project_id=self.project_id, - use_internal_ip=self.use_internal_ip, - cluster_hook=self.cluster_hook, - ).fetch_cluster_info() - - return GKEKubernetesHook( - gcp_conn_id=self.gcp_conn_id, - cluster_url=self._cluster_url, - ssl_ca_cert=self._ssl_ca_cert, - ) def execute(self, context: Context) -> dict: if self.namespace: @@ -1159,7 +827,7 @@ def execute(self, context: Context) -> dict: return V1JobList.to_dict(jobs) -class GKECreateCustomResourceOperator(KubernetesCreateResourceOperator): +class GKECreateCustomResourceOperator(GKEBaseOperator, KubernetesCreateResourceOperator): """ Create a resource in the specified Google Kubernetes Engine cluster. @@ -1173,49 +841,14 @@ class GKECreateCustomResourceOperator(KubernetesCreateResourceOperator): .. seealso:: For more information on how to use this operator, take a look at the guide: :ref:`howto/operator:GKECreateCustomResourceOperator` - - :param location: The name of the Google Kubernetes Engine zone or region in which the - cluster resides, e.g. 'us-central1-a' - :param cluster_name: The name of the Google Kubernetes Engine cluster. - :param use_internal_ip: Use the internal IP address as the endpoint. - :param project_id: The Google Developers Console project id - :param gcp_conn_id: The Google cloud connection id to use. This allows for - users to specify a service account. - :param impersonation_chain: Optional service account to impersonate using short-term - credentials, or list of accounts required to get the access_token - of the last account in the list, which will be impersonated in the request. - If set as a string, the account must grant the originating account - the Service Account Token Creator IAM role. - If set as a sequence, the identities from the list must grant - Service Account Token Creator IAM role to the directly preceding identity, with first - account from the list granting this role to the originating account (templated). 
""" template_fields: Sequence[str] = tuple( - {"project_id", "location", "cluster_name"} | set(KubernetesCreateResourceOperator.template_fields) + set(GKEBaseOperator.template_fields) | set(KubernetesCreateResourceOperator.template_fields) ) - def __init__( - self, - *, - location: str, - cluster_name: str, - use_internal_ip: bool = False, - project_id: str = PROVIDE_PROJECT_ID, - gcp_conn_id: str = "google_cloud_default", - impersonation_chain: str | Sequence[str] | None = None, - **kwargs, - ) -> None: - super().__init__(**kwargs) - self.project_id = project_id - self.location = location - self.cluster_name = cluster_name - self.gcp_conn_id = gcp_conn_id - self.impersonation_chain = impersonation_chain - self.use_internal_ip = use_internal_ip - - self._ssl_ca_cert: str | None = None - self._cluster_url: str | None = None + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) if self.gcp_conn_id is None: raise AirflowException( @@ -1224,44 +857,14 @@ def __init__( "called `google_cloud_default`.", ) # There is no need to manage the kube_config file, as it will be generated automatically. - # All Kubernetes parameters (except config_file) are also valid for the GKEStartPodOperator. + # All Kubernetes parameters (except config_file) are also valid for the GKECreateCustomResourceOperator. if self.config_file: - raise AirflowException("config_file is not an allowed parameter for the GKEStartPodOperator.") + raise AirflowException( + "config_file is not an allowed parameter for the GKECreateCustomResourceOperator." + ) - @cached_property - def cluster_hook(self) -> GKEHook: - return GKEHook( - gcp_conn_id=self.gcp_conn_id, - location=self.location, - impersonation_chain=self.impersonation_chain, - ) - @cached_property - def hook(self) -> GKEKubernetesHook: - if self._cluster_url is None or self._ssl_ca_cert is None: - raise AttributeError( - "Cluster url and ssl_ca_cert should be defined before using self.hook method. " - "Try to use self.get_kube_creds method", - ) - return GKEKubernetesHook( - gcp_conn_id=self.gcp_conn_id, - cluster_url=self._cluster_url, - ssl_ca_cert=self._ssl_ca_cert, - impersonation_chain=self.impersonation_chain, - ) - - def execute(self, context: Context): - """Execute process of creating Custom Resource.""" - self._cluster_url, self._ssl_ca_cert = GKEClusterAuthDetails( - cluster_name=self.cluster_name, - project_id=self.project_id, - use_internal_ip=self.use_internal_ip, - cluster_hook=self.cluster_hook, - ).fetch_cluster_info() - return super().execute(context) - - -class GKEDeleteCustomResourceOperator(KubernetesDeleteResourceOperator): +class GKEDeleteCustomResourceOperator(GKEBaseOperator, KubernetesDeleteResourceOperator): """ Delete a resource in the specified Google Kubernetes Engine cluster. @@ -1275,49 +878,14 @@ class GKEDeleteCustomResourceOperator(KubernetesDeleteResourceOperator): .. seealso:: For more information on how to use this operator, take a look at the guide: :ref:`howto/operator:GKEDeleteCustomResourceOperator` - - :param location: The name of the Google Kubernetes Engine zone or region in which the - cluster resides, e.g. 'us-central1-a' - :param cluster_name: The name of the Google Kubernetes Engine cluster. - :param use_internal_ip: Use the internal IP address as the endpoint. - :param project_id: The Google Developers Console project id - :param gcp_conn_id: The Google cloud connection id to use. This allows for - users to specify a service account. 
- :param impersonation_chain: Optional service account to impersonate using short-term - credentials, or list of accounts required to get the access_token - of the last account in the list, which will be impersonated in the request. - If set as a string, the account must grant the originating account - the Service Account Token Creator IAM role. - If set as a sequence, the identities from the list must grant - Service Account Token Creator IAM role to the directly preceding identity, with first - account from the list granting this role to the originating account (templated). """ template_fields: Sequence[str] = tuple( - {"project_id", "location", "cluster_name"} | set(KubernetesDeleteResourceOperator.template_fields) + set(GKEBaseOperator.template_fields) | set(KubernetesDeleteResourceOperator.template_fields) ) - def __init__( - self, - *, - location: str, - cluster_name: str, - use_internal_ip: bool = False, - project_id: str = PROVIDE_PROJECT_ID, - gcp_conn_id: str = "google_cloud_default", - impersonation_chain: str | Sequence[str] | None = None, - **kwargs, - ) -> None: - super().__init__(**kwargs) - self.project_id = project_id - self.location = location - self.cluster_name = cluster_name - self.gcp_conn_id = gcp_conn_id - self.impersonation_chain = impersonation_chain - self.use_internal_ip = use_internal_ip - - self._ssl_ca_cert: str | None = None - self._cluster_url: str | None = None + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) if self.gcp_conn_id is None: raise AirflowException( @@ -1326,77 +894,22 @@ def __init__( "called `google_cloud_default`.", ) # There is no need to manage the kube_config file, as it will be generated automatically. - # All Kubernetes parameters (except config_file) are also valid for the GKEStartPodOperator. + # All Kubernetes parameters (except config_file) are also valid for the GKEDeleteCustomResourceOperator. if self.config_file: - raise AirflowException("config_file is not an allowed parameter for the GKEStartPodOperator.") - - @cached_property - def cluster_hook(self) -> GKEHook: - return GKEHook( - gcp_conn_id=self.gcp_conn_id, - location=self.location, - impersonation_chain=self.impersonation_chain, - ) - - @cached_property - def hook(self) -> GKEKubernetesHook: - if self._cluster_url is None or self._ssl_ca_cert is None: - raise AttributeError( - "Cluster url and ssl_ca_cert should be defined before using self.hook method. " - "Try to use self.get_kube_creds method", + raise AirflowException( + "config_file is not an allowed parameter for the GKEDeleteCustomResourceOperator." ) - return GKEKubernetesHook( - gcp_conn_id=self.gcp_conn_id, - cluster_url=self._cluster_url, - ssl_ca_cert=self._ssl_ca_cert, - impersonation_chain=self.impersonation_chain, - ) - - def execute(self, context: Context): - """Execute process of deleting Custom Resource.""" - self._cluster_url, self._ssl_ca_cert = GKEClusterAuthDetails( - cluster_name=self.cluster_name, - project_id=self.project_id, - use_internal_ip=self.use_internal_ip, - cluster_hook=self.cluster_hook, - ).fetch_cluster_info() - return super().execute(context) -class GKEStartKueueJobOperator(GKEStartJobOperator): - """ - Executes a Kubernetes Job in Kueue in the specified Google Kubernetes Engine cluster. 
-
-    :param queue_name: The name of the Queue in the cluster
-    """
+class GKEStartKueueJobOperator(GKEBaseOperator, KubernetesStartKueueJobOperator):
+    """Executes a Kubernetes Job in Kueue in the specified Google Kubernetes Engine cluster."""

-    def __init__(
-        self,
-        *,
-        queue_name: str,
-        **kwargs,
-    ) -> None:
-        super().__init__(**kwargs)
-        self.queue_name = queue_name
-
-        if self.suspend is False:
-            raise AirflowException(
-                "The `suspend` parameter can't be False. If you want to use Kueue for running Job"
-                " in a Kubernetes cluster, set the `suspend` parameter to True.",
-            )
-        elif self.suspend is None:
-            warnings.warn(
-                f"You have not set parameter `suspend` in class {self.__class__.__name__}. "
-                "For running a Job in Kueue the `suspend` parameter should set to True.",
-                UserWarning,
-                stacklevel=2,
-            )
-            self.suspend = True
-        self.labels.update({"kueue.x-k8s.io/queue-name": queue_name})
-        self.annotations.update({"kueue.x-k8s.io/queue-name": queue_name})
+    template_fields: Sequence[str] = tuple(
+        set(GKEBaseOperator.template_fields) | set(KubernetesStartKueueJobOperator.template_fields)
+    )

-class GKEDeleteJobOperator(KubernetesDeleteJobOperator):
+class GKEDeleteJobOperator(GKEBaseOperator, KubernetesDeleteJobOperator):
     """
     Delete a Kubernetes job in the specified Google Kubernetes Engine cluster.

@@ -1414,49 +927,14 @@
     .. seealso::
         For more information on how to use this operator, take a look at the guide:
         :ref:`howto/operator:GKEDeleteJobOperator`
-
-    :param location: The name of the Google Kubernetes Engine zone or region in which the
-        cluster resides, e.g. 'us-central1-a'
-    :param cluster_name: The name of the Google Kubernetes Engine cluster
-    :param use_internal_ip: Use the internal IP address as the endpoint.
-    :param project_id: The Google Developers Console project id
-    :param gcp_conn_id: The Google cloud connection id to use. This allows for
-        users to specify a service account.
-    :param impersonation_chain: Optional service account to impersonate using short-term
-        credentials, or list of accounts required to get the access_token
-        of the last account in the list, which will be impersonated in the request.
-        If set as a string, the account must grant the originating account
-        the Service Account Token Creator IAM role.
-        If set as a sequence, the identities from the list must grant
-        Service Account Token Creator IAM role to the directly preceding identity, with first
-        account from the list granting this role to the originating account (templated).
""" template_fields: Sequence[str] = tuple( - {"project_id", "location", "cluster_name"} | set(KubernetesDeleteJobOperator.template_fields) + set(GKEBaseOperator.template_fields) | set(KubernetesDeleteJobOperator.template_fields) ) - def __init__( - self, - *, - location: str, - cluster_name: str, - use_internal_ip: bool = False, - project_id: str = PROVIDE_PROJECT_ID, - gcp_conn_id: str = "google_cloud_default", - impersonation_chain: str | Sequence[str] | None = None, - **kwargs, - ) -> None: - super().__init__(**kwargs) - self.project_id = project_id - self.location = location - self.cluster_name = cluster_name - self.gcp_conn_id = gcp_conn_id - self.impersonation_chain = impersonation_chain - self.use_internal_ip = use_internal_ip - - self._ssl_ca_cert: str | None = None - self._cluster_url: str | None = None + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) if self.gcp_conn_id is None: raise AirflowException( @@ -1469,41 +947,8 @@ def __init__( if self.config_file: raise AirflowException("config_file is not an allowed parameter for the GKEDeleteJobOperator.") - @cached_property - def cluster_hook(self) -> GKEHook: - return GKEHook( - gcp_conn_id=self.gcp_conn_id, - location=self.location, - impersonation_chain=self.impersonation_chain, - ) - - @cached_property - def hook(self) -> GKEKubernetesHook: - if self._cluster_url is None or self._ssl_ca_cert is None: - raise AttributeError( - "Cluster url and ssl_ca_cert should be defined before using self.hook method. " - "Try to use self.get_kube_creds method", - ) - - return GKEKubernetesHook( - gcp_conn_id=self.gcp_conn_id, - cluster_url=self._cluster_url, - ssl_ca_cert=self._ssl_ca_cert, - ) - - def execute(self, context: Context): - """Execute process of deleting Job.""" - self._cluster_url, self._ssl_ca_cert = GKEClusterAuthDetails( - cluster_name=self.cluster_name, - project_id=self.project_id, - use_internal_ip=self.use_internal_ip, - cluster_hook=self.cluster_hook, - ).fetch_cluster_info() - - return super().execute(context) - -class GKESuspendJobOperator(GoogleCloudBaseOperator): +class GKESuspendJobOperator(GKEBaseOperator): """ Suspend Job by given name. @@ -1512,84 +957,18 @@ class GKESuspendJobOperator(GoogleCloudBaseOperator): :ref:`howto/operator:GKESuspendJobOperator` :param name: The name of the Job to suspend - :param project_id: The Google Developers Console project id. - :param location: The name of the Google Kubernetes Engine zone or region in which the cluster - resides. - :param cluster_name: The name of the Google Kubernetes Engine cluster. :param namespace: The name of the Google Kubernetes Engine namespace. - :param use_internal_ip: Use the internal IP address as the endpoint. - :param gcp_conn_id: The connection ID to use connecting to Google Cloud. - :param impersonation_chain: Optional service account to impersonate using short-term - credentials, or chained list of accounts required to get the access_token - of the last account in the list, which will be impersonated in the request. - If set as a string, the account must grant the originating account - the Service Account Token Creator IAM role. - If set as a sequence, the identities from the list must grant - Service Account Token Creator IAM role to the directly preceding identity, with first - account from the list granting this role to the originating account (templated). 
""" - template_fields: Sequence[str] = ( - "project_id", - "gcp_conn_id", - "name", - "namespace", - "cluster_name", - "location", - "impersonation_chain", - ) + template_fields: Sequence[str] = tuple({"name", "namespace"} | set(GKEBaseOperator.template_fields)) operator_extra_links = (KubernetesEngineJobLink(),) - def __init__( - self, - *, - name: str, - location: str, - namespace: str, - cluster_name: str, - project_id: str = PROVIDE_PROJECT_ID, - use_internal_ip: bool = False, - gcp_conn_id: str = "google_cloud_default", - impersonation_chain: str | Sequence[str] | None = None, - **kwargs, - ) -> None: - super().__init__(**kwargs) + def __init__(self, name: str, namespace: str, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) - self.project_id = project_id - self.gcp_conn_id = gcp_conn_id - self.location = location self.name = name self.namespace = namespace - self.cluster_name = cluster_name - self.use_internal_ip = use_internal_ip - self.impersonation_chain = impersonation_chain - self.job: V1Job | None = None - self._ssl_ca_cert: str - self._cluster_url: str - - @cached_property - def cluster_hook(self) -> GKEHook: - return GKEHook( - gcp_conn_id=self.gcp_conn_id, - location=self.location, - impersonation_chain=self.impersonation_chain, - ) - - @cached_property - def hook(self) -> GKEKubernetesHook: - self._cluster_url, self._ssl_ca_cert = GKEClusterAuthDetails( - cluster_name=self.cluster_name, - project_id=self.project_id, - use_internal_ip=self.use_internal_ip, - cluster_hook=self.cluster_hook, - ).fetch_cluster_info() - - return GKEKubernetesHook( - gcp_conn_id=self.gcp_conn_id, - cluster_url=self._cluster_url, - ssl_ca_cert=self._ssl_ca_cert, - ) def execute(self, context: Context) -> None: self.job = self.hook.patch_namespaced_job( @@ -1607,7 +986,7 @@ def execute(self, context: Context) -> None: return k8s.V1Job.to_dict(self.job) -class GKEResumeJobOperator(GoogleCloudBaseOperator): +class GKEResumeJobOperator(GKEBaseOperator): """ Resume Job by given name. @@ -1616,84 +995,18 @@ class GKEResumeJobOperator(GoogleCloudBaseOperator): :ref:`howto/operator:GKEResumeJobOperator` :param name: The name of the Job to resume - :param project_id: The Google Developers Console project id. - :param location: The name of the Google Kubernetes Engine zone or region in which the cluster - resides. - :param cluster_name: The name of the Google Kubernetes Engine cluster. :param namespace: The name of the Google Kubernetes Engine namespace. - :param use_internal_ip: Use the internal IP address as the endpoint. - :param gcp_conn_id: The connection ID to use connecting to Google Cloud. - :param impersonation_chain: Optional service account to impersonate using short-term - credentials, or chained list of accounts required to get the access_token - of the last account in the list, which will be impersonated in the request. - If set as a string, the account must grant the originating account - the Service Account Token Creator IAM role. - If set as a sequence, the identities from the list must grant - Service Account Token Creator IAM role to the directly preceding identity, with first - account from the list granting this role to the originating account (templated). 
""" - template_fields: Sequence[str] = ( - "project_id", - "gcp_conn_id", - "name", - "namespace", - "cluster_name", - "location", - "impersonation_chain", - ) + template_fields: Sequence[str] = tuple({"name", "namespace"} | set(GKEBaseOperator.template_fields)) operator_extra_links = (KubernetesEngineJobLink(),) - def __init__( - self, - *, - name: str, - location: str, - namespace: str, - cluster_name: str, - project_id: str = PROVIDE_PROJECT_ID, - use_internal_ip: bool = False, - gcp_conn_id: str = "google_cloud_default", - impersonation_chain: str | Sequence[str] | None = None, - **kwargs, - ) -> None: - super().__init__(**kwargs) + def __init__(self, name: str, namespace: str, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) - self.project_id = project_id - self.gcp_conn_id = gcp_conn_id - self.location = location self.name = name self.namespace = namespace - self.cluster_name = cluster_name - self.use_internal_ip = use_internal_ip - self.impersonation_chain = impersonation_chain - self.job: V1Job | None = None - self._ssl_ca_cert: str - self._cluster_url: str - - @cached_property - def cluster_hook(self) -> GKEHook: - return GKEHook( - gcp_conn_id=self.gcp_conn_id, - location=self.location, - impersonation_chain=self.impersonation_chain, - ) - - @cached_property - def hook(self) -> GKEKubernetesHook: - self._cluster_url, self._ssl_ca_cert = GKEClusterAuthDetails( - cluster_name=self.cluster_name, - project_id=self.project_id, - use_internal_ip=self.use_internal_ip, - cluster_hook=self.cluster_hook, - ).fetch_cluster_info() - - return GKEKubernetesHook( - gcp_conn_id=self.gcp_conn_id, - cluster_url=self._cluster_url, - ssl_ca_cert=self._ssl_ca_cert, - ) def execute(self, context: Context) -> None: self.job = self.hook.patch_namespaced_job( diff --git a/providers/src/airflow/providers/google/provider.yaml b/providers/src/airflow/providers/google/provider.yaml index 96d2271f406a6..c1c4493592d95 100644 --- a/providers/src/airflow/providers/google/provider.yaml +++ b/providers/src/airflow/providers/google/provider.yaml @@ -188,7 +188,7 @@ additional-extras: - apache-beam[gcp] - name: cncf.kubernetes dependencies: - - apache-airflow-providers-cncf-kubernetes>=7.2.0 + - apache-airflow-providers-cncf-kubernetes>10.0.1 - name: leveldb dependencies: - plyvel>=1.5.1 diff --git a/providers/tests/cncf/kubernetes/hooks/test_kubernetes.py b/providers/tests/cncf/kubernetes/hooks/test_kubernetes.py index d8e73b99b700e..e512e4203f6eb 100644 --- a/providers/tests/cncf/kubernetes/hooks/test_kubernetes.py +++ b/providers/tests/cncf/kubernetes/hooks/test_kubernetes.py @@ -22,10 +22,11 @@ import tempfile from asyncio import Future from unittest import mock -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, PropertyMock, patch import kubernetes import pytest +from kubernetes.client import V1Deployment, V1DeploymentStatus from kubernetes.client.rest import ApiException from kubernetes.config import ConfigException from sqlalchemy.orm import make_transient @@ -52,6 +53,23 @@ JOB_NAME = "test-job" CONTAINER_NAME = "test-container" POLL_INTERVAL = 100 +YAML_URL = "https://test-yaml-url.com" + +NOT_READY_DEPLOYMENT = V1Deployment( + status=V1DeploymentStatus( + observed_generation=1, + ready_replicas=None, + replicas=None, + unavailable_replicas=1, + updated_replicas=None, + ) +) +READY_DEPLOYMENT = V1Deployment( + status=V1DeploymentStatus( + observed_generation=1, ready_replicas=1, replicas=1, unavailable_replicas=None, updated_replicas=1 + ) +) +DEPLOYMENT_NAME = 
"test-deployment-name" class DeprecationRemovalRequired(AirflowException): ... @@ -666,6 +684,112 @@ def test_create_job_retries_three_times(self, mock_client, mock_json_dumps): assert mock_client.create_namespaced_job.call_count == 3 + @pytest.mark.parametrize( + "given_namespace, expected_namespace", + [ + (None, "default-namespace"), + ("given-namespace", "given-namespace"), + ], + ) + @pytest.mark.parametrize( + "given_client, expected_client", + [ + (None, mock.MagicMock()), + (mock_client := mock.MagicMock(), mock_client), + ], + ) + @patch(f"{HOOK_MODULE}.utils.create_from_yaml") + @patch(f"{HOOK_MODULE}.KubernetesHook.get_namespace") + @patch(f"{HOOK_MODULE}.KubernetesHook.api_client", new_callable=PropertyMock) + def test_apply_from_yaml_file( + self, + mock_api_client, + mock_get_namespace, + mock_create_from_yaml, + given_client, + expected_client, + given_namespace, + expected_namespace, + ): + initial_kwargs = dict( + api_client=given_client, + yaml_objects=mock.MagicMock(), + yaml_file=mock.MagicMock(), + verbose=mock.MagicMock(), + namespace=given_namespace, + ) + expected_kwargs = dict( + k8s_client=expected_client, + yaml_objects=initial_kwargs["yaml_objects"], + yaml_file=initial_kwargs["yaml_file"], + verbose=initial_kwargs["verbose"], + namespace=expected_namespace, + ) + mock_api_client.return_value = expected_client + mock_get_namespace.return_value = expected_namespace + + KubernetesHook().apply_from_yaml_file(**initial_kwargs) + + mock_create_from_yaml.assert_called_once_with(**expected_kwargs) + if given_client is None: + mock_api_client.assert_called_once() + if given_namespace is None: + mock_get_namespace.assert_called_once() + + @mock.patch(HOOK_MODULE + ".sleep") + @mock.patch(HOOK_MODULE + ".KubernetesHook.log") + @mock.patch(HOOK_MODULE + ".KubernetesHook.get_deployment_status") + def test_check_kueue_deployment_running(self, mock_get_deployment_status, mock_log, mock_sleep): + mock_get_deployment_status.side_effect = [ + NOT_READY_DEPLOYMENT, + READY_DEPLOYMENT, + ] + + KubernetesHook().check_kueue_deployment_running(name=DEPLOYMENT_NAME, namespace=NAMESPACE) + + mock_log.info.assert_called_once_with("Waiting until Deployment will be ready...") + mock_sleep.assert_called_once_with(2.0) + + @mock.patch(HOOK_MODULE + ".KubernetesHook.log") + @mock.patch(HOOK_MODULE + ".KubernetesHook.get_deployment_status") + def test_check_kueue_deployment_raise_exception(self, mock_get_deployment_status, mock_log): + mock_get_deployment_status.side_effect = ValueError + + with pytest.raises(ValueError): + KubernetesHook().check_kueue_deployment_running(name=DEPLOYMENT_NAME, namespace=NAMESPACE) + + mock_log.exception.assert_called_once_with("Exception occurred while checking for Deployment status.") + + @mock.patch(f"{HOOK_MODULE}.yaml") + @mock.patch(f"{HOOK_MODULE}.requests") + def test_get_yaml_content_from_file(self, mock_requests, mock_yaml): + mock_get = mock_requests.get + mock_response = mock_get.return_value + expected_response_text = "test response text" + mock_response.text = expected_response_text + mock_response.status_code = 200 + expected_result = list(mock_yaml.safe_load_all.return_value) + + result = KubernetesHook().get_yaml_content_from_file(YAML_URL) + + mock_get.assert_called_with(YAML_URL, allow_redirects=True) + mock_yaml.safe_load_all.assert_called_with(expected_response_text) + assert result == expected_result + + @mock.patch(f"{HOOK_MODULE}.yaml") + @mock.patch(f"{HOOK_MODULE}.requests") + def test_get_yaml_content_from_file_error(self, 
mock_requests, mock_yaml):
+        mock_get = mock_requests.get
+        mock_response = mock_get.return_value
+        mock_response.status_code = 500
+        expected_error_message = "Was not able to read the yaml file from given URL"
+
+        with pytest.raises(AirflowException, match=expected_error_message):
+            KubernetesHook().get_yaml_content_from_file(YAML_URL)
+
+        mock_get.assert_called_with(YAML_URL, allow_redirects=True)
+        mock_yaml.safe_load_all.assert_not_called()
+

 class TestKubernetesHookIncorrectConfiguration:
     @pytest.mark.parametrize(
diff --git a/providers/tests/cncf/kubernetes/operators/test_kueue.py b/providers/tests/cncf/kubernetes/operators/test_kueue.py
new file mode 100644
index 0000000000000..286b22247e001
--- /dev/null
+++ b/providers/tests/cncf/kubernetes/operators/test_kueue.py
@@ -0,0 +1,158 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import json
+from unittest import mock
+
+import pytest
+from kubernetes.utils import FailToCreateError
+
+from airflow.exceptions import AirflowException
+from airflow.providers.cncf.kubernetes.operators.job import KubernetesJobOperator
+from airflow.providers.cncf.kubernetes.operators.kueue import (
+    KubernetesInstallKueueOperator,
+    KubernetesStartKueueJobOperator,
+)
+
+TEST_TASK_ID = "test_task"
+TEST_K8S_CONN_ID = "test_kubernetes_conn_id"
+TEST_ERROR_CLASS = "TestException"
+TEST_ERROR_BODY = "Test Exception Body"
+TEST_QUEUE_NAME = "test-queue-name"
+
+KUEUE_VERSION = "v0.9.1"
+KUEUE_YAML_URL = f"https://github.com/kubernetes-sigs/kueue/releases/download/{KUEUE_VERSION}/manifests.yaml"
+KUEUE_OPERATORS_PATH = "airflow.providers.cncf.kubernetes.operators.kueue.{}"
+
+
+class TestKubernetesInstallKueueOperator:
+    def setup_method(self):
+        self.operator = KubernetesInstallKueueOperator(
+            task_id=TEST_TASK_ID,
+            kueue_version=KUEUE_VERSION,
+            kubernetes_conn_id=TEST_K8S_CONN_ID,
+        )
+
+    def test_template_fields(self):
+        expected_template_fields = {"kueue_version", "kubernetes_conn_id"}
+        assert set(KubernetesInstallKueueOperator.template_fields) == expected_template_fields
+
+    @mock.patch(KUEUE_OPERATORS_PATH.format("KubernetesHook"))
+    def test_hook(self, mock_hook):
+        mock_hook_instance = mock_hook.return_value
+
+        actual_hook = self.operator.hook
+
+        mock_hook.assert_called_once_with(conn_id=TEST_K8S_CONN_ID)
+        assert actual_hook == mock_hook_instance
+
+    @mock.patch(KUEUE_OPERATORS_PATH.format("KubernetesInstallKueueOperator.log"))
+    @mock.patch(KUEUE_OPERATORS_PATH.format("KubernetesHook"))
+    def test_execute(self, mock_hook, mock_log):
+        mock_get_yaml_content_from_file = mock_hook.return_value.get_yaml_content_from_file
+        mock_yaml_objects = mock_get_yaml_content_from_file.return_value
+
+        self.operator.execute(context=mock.MagicMock())
+
+
mock_get_yaml_content_from_file.assert_called_once_with(kueue_yaml_url=KUEUE_YAML_URL) + mock_hook.return_value.apply_from_yaml_file.assert_called_once_with(yaml_objects=mock_yaml_objects) + mock_hook.return_value.check_kueue_deployment_running.assert_called_once_with( + name="kueue-controller-manager", + namespace="kueue-system", + ) + mock_log.info.assert_called_once_with("Kueue installed successfully!") + + @mock.patch(KUEUE_OPERATORS_PATH.format("KubernetesInstallKueueOperator.log")) + @mock.patch(KUEUE_OPERATORS_PATH.format("KubernetesHook")) + def test_execute_already_exist(self, mock_hook, mock_log): + mock_get_yaml_content_from_file = mock_hook.return_value.get_yaml_content_from_file + mock_yaml_objects = mock_get_yaml_content_from_file.return_value + mock_apply_from_yaml_file = mock_hook.return_value.apply_from_yaml_file + api_exceptions = [mock.MagicMock(body=json.dumps({"reason": "AlreadyExists"})) for _ in range(4)] + mock_apply_from_yaml_file.side_effect = FailToCreateError(api_exceptions) + + self.operator.execute(context=mock.MagicMock()) + + mock_get_yaml_content_from_file.assert_called_once_with(kueue_yaml_url=KUEUE_YAML_URL) + mock_apply_from_yaml_file.assert_called_once_with(yaml_objects=mock_yaml_objects) + mock_hook.return_value.check_kueue_deployment_running.assert_not_called() + mock_log.info.assert_called_once_with("Kueue is already enabled for the cluster") + + @mock.patch(KUEUE_OPERATORS_PATH.format("KubernetesInstallKueueOperator.log")) + @mock.patch(KUEUE_OPERATORS_PATH.format("KubernetesHook")) + def test_execute_error(self, mock_hook, mock_log): + mock_get_yaml_content_from_file = mock_hook.return_value.get_yaml_content_from_file + mock_yaml_objects = mock_get_yaml_content_from_file.return_value + mock_apply_from_yaml_file = mock_hook.return_value.apply_from_yaml_file + api_exceptions = [ + mock.MagicMock(body=json.dumps({"reason": "AlreadyExists"})), + mock.MagicMock(body=json.dumps({"reason": TEST_ERROR_CLASS, "body": TEST_ERROR_BODY})), + mock.MagicMock(body=json.dumps({"reason": TEST_ERROR_CLASS, "body": TEST_ERROR_BODY})), + ] + mock_apply_from_yaml_file.side_effect = FailToCreateError(api_exceptions) + expected_error_message = f"{TEST_ERROR_BODY}\n{TEST_ERROR_BODY}" + + with pytest.raises(AirflowException, match=expected_error_message): + self.operator.execute(context=mock.MagicMock()) + + mock_get_yaml_content_from_file.assert_called_once_with(kueue_yaml_url=KUEUE_YAML_URL) + mock_apply_from_yaml_file.assert_called_once_with(yaml_objects=mock_yaml_objects) + mock_hook.return_value.check_kueue_deployment_running.assert_not_called() + mock_log.info.assert_called_once_with("Kueue is already enabled for the cluster") + + +class TestKubernetesStartKueueJobOperator: + def test_template_fields(self): + expected_template_fields = {"queue_name"} | set(KubernetesJobOperator.template_fields) + assert set(KubernetesStartKueueJobOperator.template_fields) == expected_template_fields + + def test_init(self): + operator = KubernetesStartKueueJobOperator( + task_id=TEST_TASK_ID, queue_name=TEST_QUEUE_NAME, suspend=True + ) + + assert operator.queue_name == TEST_QUEUE_NAME + assert operator.suspend is True + assert operator.labels == {"kueue.x-k8s.io/queue-name": TEST_QUEUE_NAME} + assert operator.annotations == {"kueue.x-k8s.io/queue-name": TEST_QUEUE_NAME} + + def test_init_suspend_is_false(self): + expected_error_message = ( + "The `suspend` parameter can't be False. 
If you want to use Kueue for running Job" + " in a Kubernetes cluster, set the `suspend` parameter to True." + ) + with pytest.raises(AirflowException, match=expected_error_message): + KubernetesStartKueueJobOperator(task_id=TEST_TASK_ID, queue_name=TEST_QUEUE_NAME, suspend=False) + + @mock.patch(KUEUE_OPERATORS_PATH.format("warnings")) + def test_init_suspend_is_none(self, mock_warnings): + operator = KubernetesStartKueueJobOperator( + task_id=TEST_TASK_ID, + queue_name=TEST_QUEUE_NAME, + ) + + assert operator.queue_name == TEST_QUEUE_NAME + assert operator.suspend is True + assert operator.labels == {"kueue.x-k8s.io/queue-name": TEST_QUEUE_NAME} + assert operator.annotations == {"kueue.x-k8s.io/queue-name": TEST_QUEUE_NAME} + mock_warnings.warn.assert_called_once_with( + "You have not set parameter `suspend` in class KubernetesStartKueueJobOperator. " + "For running a Job in Kueue the `suspend` parameter should set to True.", + UserWarning, + stacklevel=2, + ) diff --git a/providers/tests/deprecations_ignore.yml b/providers/tests/deprecations_ignore.yml index 6e3adc6875fb6..317b3b3a45153 100644 --- a/providers/tests/deprecations_ignore.yml +++ b/providers/tests/deprecations_ignore.yml @@ -94,7 +94,6 @@ - providers/tests/google/cloud/operators/test_dataproc.py::test_scale_cluster_operator_extra_links - providers/tests/google/cloud/operators/test_dataproc.py::test_submit_spark_job_operator_extra_links - providers/tests/google/cloud/operators/test_gcs.py::TestGoogleCloudStorageListOperator::test_execute__delimiter -- providers/tests/google/cloud/operators/test_kubernetes_engine.py::TestGoogleCloudPlatformContainerOperator::test_create_execute_error_body - providers/tests/google/cloud/operators/test_life_sciences.py::TestLifeSciencesRunPipelineOperator::test_executes - providers/tests/google/cloud/operators/test_life_sciences.py::TestLifeSciencesRunPipelineOperator::test_executes_without_project_id - providers/tests/google/cloud/transfers/test_gcs_to_gcs.py::TestGoogleCloudStorageToCloudStorageOperator::test_copy_files_into_a_folder diff --git a/providers/tests/google/cloud/hooks/test_kubernetes_engine.py b/providers/tests/google/cloud/hooks/test_kubernetes_engine.py index 9f563246ff44d..848a1b4b8c12a 100644 --- a/providers/tests/google/cloud/hooks/test_kubernetes_engine.py +++ b/providers/tests/google/cloud/hooks/test_kubernetes_engine.py @@ -17,6 +17,7 @@ # under the License. 
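# The GKE hook tests below verify that GKEKubernetesHook.apply_from_yaml_file only fills
# in its own authenticated client before delegating to the cncf.kubernetes base
# implementation. For reference, a minimal end-to-end sketch of the Kueue install flow
# built on those base helpers; the connection id and manifest URL here are placeholders,
# not values taken from the provider code:
from airflow.providers.cncf.kubernetes.hooks.kubernetes import KubernetesHook


def install_kueue_sketch(manifest_url: str) -> None:
    hook = KubernetesHook(conn_id="kubernetes_default")
    # Download the multi-document manifest and split it into one dict per object.
    yaml_objects = hook.get_yaml_content_from_file(kueue_yaml_url=manifest_url)
    # Create every object in the cluster; the target namespaces must already exist.
    hook.apply_from_yaml_file(yaml_objects=yaml_objects)
    # Poll until the Kueue controller Deployment reports all replicas ready.
    hook.check_kueue_deployment_running(name="kueue-controller-manager", namespace="kueue-system")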
from __future__ import annotations +import copy from asyncio import Future from unittest import mock @@ -24,7 +25,6 @@ import pytest from google.cloud.container_v1 import ClusterManagerAsyncClient from google.cloud.container_v1.types import Cluster -from kubernetes.client.models import V1Deployment, V1DeploymentStatus from airflow.exceptions import AirflowException from airflow.providers.google.cloud.hooks.kubernetes_engine import ( @@ -146,20 +146,6 @@ "status": {"phase": "Running"}, }, } -NOT_READY_DEPLOYMENT = V1Deployment( - status=V1DeploymentStatus( - observed_generation=1, - ready_replicas=None, - replicas=None, - unavailable_replicas=1, - updated_replicas=None, - ) -) -READY_DEPLOYMENT = V1Deployment( - status=V1DeploymentStatus( - observed_generation=1, ready_replicas=1, replicas=1, unavailable_replicas=None, updated_replicas=1 - ) -) @pytest.mark.db_test @@ -462,25 +448,32 @@ def _get_config(self): def _get_credentials(self): return self.credentials - @mock.patch("kubernetes.client.AppsV1Api") - def test_check_kueue_deployment_running(self, gke_deployment_hook, caplog): - self.gke_hook.get_credentials = self._get_credentials - gke_deployment_hook.return_value.read_namespaced_deployment_status.side_effect = [ - NOT_READY_DEPLOYMENT, - READY_DEPLOYMENT, - ] - self.gke_hook.check_kueue_deployment_running(name=CLUSTER_NAME, namespace=NAMESPACE) - - assert "Waiting until Deployment will be ready..." in caplog.text + @pytest.mark.parametrize( + "api_client, expected_client", + [ + (None, mock.MagicMock()), + (mock_client := mock.MagicMock(), mock_client), + ], + ) + @mock.patch(GKE_STRING.format("super")) + @mock.patch(GKE_STRING.format("GKEKubernetesHook.get_conn")) + def test_apply_from_yaml_file(self, mock_get_conn, mock_super, api_client, expected_client): + kwargs = dict( + api_client=api_client, + yaml_file=mock.MagicMock(), + yaml_objects=mock.MagicMock(), + verbose=mock.MagicMock(), + namespace=mock.MagicMock(), + ) + expected_kwargs = copy.deepcopy(kwargs) + expected_kwargs["api_client"] = expected_client + mock_get_conn.return_value = expected_client - @mock.patch("kubernetes.client.AppsV1Api") - def test_check_kueue_deployment_raise_exception(self, gke_deployment_hook, caplog): - self.gke_hook.get_credentials = self._get_credentials - gke_deployment_hook.return_value.read_namespaced_deployment_status.side_effect = ValueError() - with pytest.raises(ValueError): - self.gke_hook.check_kueue_deployment_running(name=CLUSTER_NAME, namespace=NAMESPACE) + self.gke_hook.apply_from_yaml_file(**kwargs) - assert "Exception occurred while checking for Deployment status." in caplog.text + if api_client is None: + mock_get_conn.assert_called_once() + mock_super.return_value.apply_from_yaml_file.assert_called_once_with(**expected_kwargs) class TestGKEKubernetesAsyncHook: diff --git a/providers/tests/google/cloud/operators/test_kubernetes_engine.py b/providers/tests/google/cloud/operators/test_kubernetes_engine.py index 3127b5d89ca9e..3dce4a508eba4 100644 --- a/providers/tests/google/cloud/operators/test_kubernetes_engine.py +++ b/providers/tests/google/cloud/operators/test_kubernetes_engine.py @@ -17,19 +17,26 @@ # under the License. 
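# The operator tests below revolve around the new GKEBaseOperator, which resolves GKE
# credentials lazily instead of each operator re-implementing the logic. A condensed,
# illustrative sketch of that flow; the names come from this PR, but the real operator
# caches these values as properties rather than rebuilding them per call:
from airflow.providers.google.cloud.hooks.kubernetes_engine import GKEHook, GKEKubernetesHook
from airflow.providers.google.cloud.operators.kubernetes_engine import GKEClusterAuthDetails


def build_gke_kubernetes_hook_sketch(
    gcp_conn_id: str,
    location: str,
    project_id: str,
    cluster_name: str,
    use_internal_ip: bool = False,
) -> GKEKubernetesHook:
    # GKEHook talks to the GKE control plane to look up the cluster endpoint and CA cert.
    cluster_hook = GKEHook(gcp_conn_id=gcp_conn_id, location=location, impersonation_chain=None)
    cluster_url, ssl_ca_cert = GKEClusterAuthDetails(
        cluster_name=cluster_name,
        project_id=project_id,
        use_internal_ip=use_internal_ip,
        cluster_hook=cluster_hook,
    ).fetch_cluster_info()
    # GKEKubernetesHook then talks to the cluster itself using those credentials.
    return GKEKubernetesHook(
        gcp_conn_id=gcp_conn_id,
        cluster_url=cluster_url,
        ssl_ca_cert=ssl_ca_cert,
    )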
from __future__ import annotations -import json -import os +from copy import deepcopy from unittest import mock -from unittest.mock import mock_open +from unittest.mock import PropertyMock, call import pytest +from google.api_core.exceptions import AlreadyExists from google.cloud.container_v1.types import Cluster, NodePool -from kubernetes.client.models import V1Deployment, V1DeploymentStatus -from kubernetes.utils.create_from_yaml import FailToCreateError -from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, TaskDeferred -from airflow.models import Connection -from airflow.providers.cncf.kubernetes.operators.job import KubernetesDeleteJobOperator, KubernetesJobOperator +from airflow.exceptions import ( + AirflowException, + AirflowProviderDeprecationWarning, +) +from airflow.providers.cncf.kubernetes.operators.job import ( + KubernetesDeleteJobOperator, + KubernetesJobOperator, +) +from airflow.providers.cncf.kubernetes.operators.kueue import ( + KubernetesInstallKueueOperator, + KubernetesStartKueueJobOperator, +) from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator from airflow.providers.cncf.kubernetes.operators.resource import ( KubernetesCreateResourceOperator, @@ -37,12 +44,15 @@ ) from airflow.providers.cncf.kubernetes.utils.pod_manager import OnFinishAction from airflow.providers.google.cloud.operators.kubernetes_engine import ( + GKEBaseOperator, + GKEClusterAuthDetails, GKECreateClusterOperator, GKECreateCustomResourceOperator, GKEDeleteClusterOperator, GKEDeleteCustomResourceOperator, GKEDeleteJobOperator, GKEDescribeJobOperator, + GKEListJobsOperator, GKEResumeJobOperator, GKEStartJobOperator, GKEStartKueueInsideClusterOperator, @@ -50,123 +60,287 @@ GKEStartPodOperator, GKESuspendJobOperator, ) -from airflow.providers.google.cloud.triggers.kubernetes_engine import GKEStartPodTrigger - -TEST_GCP_PROJECT_ID = "test-id" -PROJECT_LOCATION = "test-location" -PROJECT_TASK_ID = "test-task-id" -CLUSTER_NAME = "test-cluster-name" -QUEUE_NAME = "test-queue-name" - -PROJECT_BODY = {"name": "test-name"} -PROJECT_BODY_CREATE_DICT = {"name": "test-name", "initial_node_count": 1} -PROJECT_BODY_CREATE_DICT_NODE_POOLS = { - "name": "test-name", + +TEST_PROJECT_ID = "test-id" +TEST_LOCATION = "test-location" +TEST_TASK_ID = "test-task-id" +TEST_CONN_ID = "test-conn-id" +TEST_IMPERSONATION_CHAIN = "test-sa-@google.com" + +TEST_OPERATION_NAME = "test-operation-name" +TEST_SELF_LINK = "test-self-link" +TEST_TARGET_LINK = "test-target-link" +TEST_IMAGE = "bash" +TEST_POLL_INTERVAL = 20.0 + +GKE_CLUSTER_NAME = "test-cluster-name" +GKE_CLUSTER_ENDPOINT = "test-host" +GKE_CLUSTER_PRIVATE_ENDPOINT = "test-private-host" +GKE_CLUSTER_URL = f"https://{GKE_CLUSTER_ENDPOINT}" +GKE_CLUSTER_PRIVATE_URL = f"https://{GKE_CLUSTER_PRIVATE_ENDPOINT}" +GKE_SSL_CA_CERT = "TEST_SSL_CA_CERT_CONTENT" + +GKE_CLUSTER_CREATE_BODY_DICT = { + "name": GKE_CLUSTER_NAME, "node_pools": [{"name": "a_node_pool", "initial_node_count": 1}], } - -PROJECT_BODY_CREATE_CLUSTER = Cluster(name="test-name", initial_node_count=1) -PROJECT_BODY_CREATE_CLUSTER_NODE_POOLS = Cluster( - name="test-name", node_pools=[NodePool(name="a_node_pool", initial_node_count=1)] +GKE_CLUSTER_CREATE_BODY_OBJECT = Cluster( + name=GKE_CLUSTER_NAME, node_pools=[NodePool(name="a_node_pool", initial_node_count=1)] ) +GKE_CLUSTER_CREATE_BODY_DICT_DEPRECATED = {"name": GKE_CLUSTER_NAME, "initial_node_count": 1} +GKE_CLUSTER_CREATE_BODY_OBJECT_DEPRECATED = Cluster(name=GKE_CLUSTER_NAME, 
initial_node_count=1) -TASK_NAME = "test-task-name" -JOB_NAME = "test-job" -POD_NAME = "test-pod" -NAMESPACE = ("default",) -IMAGE = "bash" -JOB_POLL_INTERVAL = 20.0 - -GCLOUD_COMMAND = "gcloud container clusters get-credentials {} --zone {} --project {}" -KUBE_ENV_VAR = "KUBECONFIG" -FILE_NAME = "/tmp/mock_name" -KUB_OP_PATH = "airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator.{}" -GKE_HOOK_MODULE_PATH = "airflow.providers.google.cloud.operators.kubernetes_engine" -GKE_HOOK_PATH = f"{GKE_HOOK_MODULE_PATH}.GKEHook" -GKE_KUBERNETES_HOOK = f"{GKE_HOOK_MODULE_PATH}.GKEKubernetesHook" -GKE_K8S_HOOK_PATH = f"{GKE_HOOK_MODULE_PATH}.GKEKubernetesHook" -KUB_OPERATOR_EXEC = "airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator.execute" -KUB_JOB_OPERATOR_EXEC = "airflow.providers.cncf.kubernetes.operators.job.KubernetesJobOperator.execute" -KUB_CREATE_RES_OPERATOR_EXEC = ( - "airflow.providers.cncf.kubernetes.operators.resource.KubernetesCreateResourceOperator.execute" -) -KUB_DELETE_RES_OPERATOR_EXEC = ( - "airflow.providers.cncf.kubernetes.operators.resource.KubernetesDeleteResourceOperator.execute" -) -DEL_KUB_JOB_OPERATOR_EXEC = ( - "airflow.providers.cncf.kubernetes.operators.job.KubernetesDeleteJobOperator.execute" -) -TEMP_FILE = "tempfile.NamedTemporaryFile" -GKE_OP_PATH = "airflow.providers.google.cloud.operators.kubernetes_engine.GKEStartPodOperator" -GKE_CREATE_CLUSTER_PATH = ( - "airflow.providers.google.cloud.operators.kubernetes_engine.GKECreateClusterOperator" -) -GKE_JOB_OP_PATH = "airflow.providers.google.cloud.operators.kubernetes_engine.GKEStartJobOperator" -GKE_CLUSTER_AUTH_DETAILS_PATH = ( - "airflow.providers.google.cloud.operators.kubernetes_engine.GKEClusterAuthDetails" -) -CLUSTER_URL = "https://test-host" -CLUSTER_PRIVATE_URL = "https://test-private-host" -SSL_CA_CERT = "TEST_SSL_CA_CERT_CONTENT" -KUEUE_VERSION = "v0.5.1" -IMPERSONATION_CHAIN = "sa-@google.com" -USE_INTERNAL_API = False -READY_DEPLOYMENT = V1Deployment( - status=V1DeploymentStatus( - observed_generation=1, ready_replicas=1, replicas=1, unavailable_replicas=None, updated_replicas=1 - ) -) -VALID_RESOURCE_YAML = """ -apiVersion: v1 -kind: PersistentVolumeClaim -metadata: - name: test_pvc -spec: - accessModes: - - ReadWriteOnce - storageClassName: standard - resources: - requests: - storage: 5Gi -""" -KUEUE_YAML_URL = "http://test-url/config.yaml" - - -class TestGoogleCloudPlatformContainerOperator: +K8S_KUEUE_VERSION = "v0.9.1" +K8S_JOB_NAME = "test-job" +K8S_POD_NAME = "test-pod" +K8S_NAMESPACE = "default" + +GKE_OPERATORS_PATH = "airflow.providers.google.cloud.operators.kubernetes_engine.{}" + + +class TestGKEClusterAuthDetails: @pytest.mark.parametrize( - "body", + "use_internal_ip, endpoint, private_endpoint, expected_cluster_url", [ - PROJECT_BODY_CREATE_DICT, - PROJECT_BODY_CREATE_DICT_NODE_POOLS, - PROJECT_BODY_CREATE_CLUSTER, - PROJECT_BODY_CREATE_CLUSTER_NODE_POOLS, + (False, GKE_CLUSTER_ENDPOINT, GKE_CLUSTER_PRIVATE_ENDPOINT, GKE_CLUSTER_URL), + (True, GKE_CLUSTER_ENDPOINT, GKE_CLUSTER_PRIVATE_ENDPOINT, GKE_CLUSTER_PRIVATE_URL), ], ) - @mock.patch(GKE_HOOK_PATH) - def test_create_execute(self, mock_hook, body): - print("type: ", type(body)) - if body == PROJECT_BODY_CREATE_DICT or body == PROJECT_BODY_CREATE_CLUSTER: - with pytest.warns( - AirflowProviderDeprecationWarning, - match="The body field 'initial_node_count' is deprecated. 
Use 'node_pool.initial_node_count' instead.", - ): - operator = GKECreateClusterOperator( - project_id=TEST_GCP_PROJECT_ID, - location=PROJECT_LOCATION, - body=body, - task_id=PROJECT_TASK_ID, - ) - else: - operator = GKECreateClusterOperator( - project_id=TEST_GCP_PROJECT_ID, location=PROJECT_LOCATION, body=body, task_id=PROJECT_TASK_ID - ) + def test_fetch_cluster_info(self, use_internal_ip, endpoint, private_endpoint, expected_cluster_url): + mock_cluster = mock.MagicMock( + endpoint=endpoint, + private_cluster_config=mock.MagicMock(private_endpoint=private_endpoint), + master_auth=mock.MagicMock(cluster_ca_certificate=GKE_SSL_CA_CERT), + ) + mock_cluster_hook = mock.MagicMock(get_cluster=mock.MagicMock(return_value=mock_cluster)) + + cluster_auth_details = GKEClusterAuthDetails( + cluster_name=GKE_CLUSTER_NAME, + project_id=TEST_PROJECT_ID, + use_internal_ip=use_internal_ip, + cluster_hook=mock_cluster_hook, + ) + + cluster_url, ssl_ca_cert = cluster_auth_details.fetch_cluster_info() + assert expected_cluster_url == cluster_url + assert ssl_ca_cert == GKE_SSL_CA_CERT + mock_cluster_hook.get_cluster.assert_called_once_with( + name=GKE_CLUSTER_NAME, project_id=TEST_PROJECT_ID + ) + + +class TestGKEBaseOperator: + def setup_method(self): + self.operator = GKEBaseOperator( + task_id=TEST_TASK_ID, + project_id=TEST_PROJECT_ID, + location=TEST_LOCATION, + cluster_name=GKE_CLUSTER_NAME, + gcp_conn_id=TEST_CONN_ID, + use_internal_ip=False, + impersonation_chain=TEST_IMPERSONATION_CHAIN, + ) + + def test_template_fields(self): + expected_template_fields = { + "location", + "cluster_name", + "use_internal_ip", + "project_id", + "gcp_conn_id", + "impersonation_chain", + } + assert set(self.operator.template_fields) == expected_template_fields + + @mock.patch(GKE_OPERATORS_PATH.format("GKEHook")) + def test_cluster_hook(self, mock_cluster_hook): + actual_hook = self.operator.cluster_hook + + assert actual_hook == mock_cluster_hook.return_value + mock_cluster_hook.assert_called_once_with( + gcp_conn_id=TEST_CONN_ID, + location=TEST_LOCATION, + impersonation_chain=TEST_IMPERSONATION_CHAIN, + ) + + @mock.patch(GKE_OPERATORS_PATH.format("GKEHook")) + @mock.patch(GKE_OPERATORS_PATH.format("GKEClusterAuthDetails")) + @mock.patch(GKE_OPERATORS_PATH.format("GKEKubernetesHook")) + def test_hook(self, mock_hook, mock_cluster_auth_details, mock_cluster_hook): + mock_cluster_auth_details.return_value.fetch_cluster_info.return_value = ( + GKE_CLUSTER_URL, + GKE_SSL_CA_CERT, + ) + actual_hook = self.operator.hook + + assert actual_hook == mock_hook.return_value + mock_hook.assert_called_once_with( + gcp_conn_id=TEST_CONN_ID, + impersonation_chain=TEST_IMPERSONATION_CHAIN, + cluster_url=GKE_CLUSTER_URL, + ssl_ca_cert=GKE_SSL_CA_CERT, + enable_tcp_keepalive=False, + ) + + @mock.patch(GKE_OPERATORS_PATH.format("GKEHook")) + @mock.patch(GKE_OPERATORS_PATH.format("GKEClusterAuthDetails")) + def test_cluster_info(self, mock_cluster_auth_details, mock_cluster_hook): + mock_fetch_cluster_info = mock_cluster_auth_details.return_value.fetch_cluster_info + mock_fetch_cluster_info.return_value = (GKE_CLUSTER_URL, GKE_SSL_CA_CERT) + + cluster_info = self.operator.cluster_info - operator.execute(context=mock.MagicMock()) - mock_hook.return_value.create_cluster.assert_called_once_with( - cluster=body, - project_id=TEST_GCP_PROJECT_ID, + assert cluster_info == (GKE_CLUSTER_URL, GKE_SSL_CA_CERT) + mock_cluster_auth_details.assert_called_once_with( + cluster_name=self.operator.cluster_name, + project_id=self.operator.project_id, 
+ use_internal_ip=self.operator.use_internal_ip, + cluster_hook=self.operator.cluster_hook, + ) + mock_fetch_cluster_info.assert_called_once_with() + + @mock.patch(GKE_OPERATORS_PATH.format("GKEHook")) + @mock.patch(GKE_OPERATORS_PATH.format("GKEBaseOperator.cluster_info"), new_callable=PropertyMock) + def test_cluster_url(self, mock_cluster_info, mock_cluster_hook): + mock_cluster_info.return_value = (GKE_CLUSTER_URL, GKE_SSL_CA_CERT) + + cluster_url = self.operator.cluster_url + + assert cluster_url == GKE_CLUSTER_URL + + @mock.patch(GKE_OPERATORS_PATH.format("GKEHook")) + @mock.patch(GKE_OPERATORS_PATH.format("GKEBaseOperator.cluster_info"), new_callable=PropertyMock) + def test_ssl_ca_cert(self, mock_cluster_info, mock_cluster_hook): + mock_cluster_info.return_value = (GKE_CLUSTER_URL, GKE_SSL_CA_CERT) + + ssl_ca_cert = self.operator.ssl_ca_cert + + assert ssl_ca_cert == GKE_SSL_CA_CERT + + +class TestGKEDeleteClusterOperator: + def setup_method(self): + self.operator = GKEDeleteClusterOperator( + task_id=TEST_TASK_ID, + project_id=TEST_PROJECT_ID, + location=TEST_LOCATION, + cluster_name=GKE_CLUSTER_NAME, + gcp_conn_id=TEST_CONN_ID, + impersonation_chain=TEST_IMPERSONATION_CHAIN, + ) + + def test_template_fields(self): + expected_template_fields = {"api_version", "deferrable", "poll_interval"} | set( + GKEBaseOperator.template_fields + ) + assert set(self.operator.template_fields) == expected_template_fields + + @pytest.mark.parametrize("missing_parameter", ["project_id", "location", "cluster_name"]) + def test_check_input(self, missing_parameter): + setattr(self.operator, missing_parameter, None) + + with pytest.raises(AirflowException): + self.operator._check_input() + + @mock.patch(GKE_OPERATORS_PATH.format("GKEHook")) + def test_execute(self, mock_cluster_hook): + mock_delete_cluster = mock_cluster_hook.return_value.delete_cluster + mock_operation = mock_delete_cluster.return_value + mock_operation.self_link = TEST_SELF_LINK + + result = self.operator.execute(context=mock.MagicMock()) + + mock_delete_cluster.assert_called_once_with( + name=GKE_CLUSTER_NAME, + project_id=TEST_PROJECT_ID, wait_to_complete=True, ) + assert result == TEST_SELF_LINK + + @mock.patch(GKE_OPERATORS_PATH.format("GKEOperationTrigger")) + @mock.patch(GKE_OPERATORS_PATH.format("GKEDeleteClusterOperator.defer")) + @mock.patch(GKE_OPERATORS_PATH.format("GKEHook")) + def test_deferrable(self, mock_cluster_hook, mock_defer, mock_trigger): + mock_delete_cluster = mock_cluster_hook.return_value.delete_cluster + mock_operation = mock_delete_cluster.return_value + mock_operation.name = TEST_OPERATION_NAME + mock_trigger_instance = mock_trigger.return_value + self.operator.deferrable = True + + self.operator.execute(context=mock.MagicMock()) + + mock_delete_cluster.assert_called_once_with( + name=GKE_CLUSTER_NAME, + project_id=TEST_PROJECT_ID, + wait_to_complete=False, + ) + mock_trigger.assert_called_once_with( + operation_name=TEST_OPERATION_NAME, + project_id=TEST_PROJECT_ID, + location=TEST_LOCATION, + gcp_conn_id=TEST_CONN_ID, + impersonation_chain=TEST_IMPERSONATION_CHAIN, + poll_interval=10, + ) + mock_defer.assert_called_once_with( + trigger=mock_trigger_instance, + method_name="execute_complete", + ) + + @mock.patch(GKE_OPERATORS_PATH.format("GKEDeleteClusterOperator.log")) + @mock.patch(GKE_OPERATORS_PATH.format("GKEHook")) + def test_execute_complete(self, cluster_hook, mock_log): + mock_get_operation = cluster_hook.return_value.get_operation + mock_get_operation.return_value.self_link = TEST_SELF_LINK + 
expected_status, expected_message = "success", "test-message" + event = dict(status=expected_status, message=expected_message, operation_name=TEST_OPERATION_NAME) + + result = self.operator.execute_complete(context=mock.MagicMock(), event=event) + + mock_log.info.assert_called_once_with(expected_message) + mock_get_operation.assert_called_once_with(operation_name=TEST_OPERATION_NAME) + assert result == TEST_SELF_LINK + + @pytest.mark.parametrize("status", ["failed", "error"]) + @mock.patch(GKE_OPERATORS_PATH.format("GKEDeleteClusterOperator.log")) + def test_execute_complete_error(self, mock_log, status): + expected_message = "test-message" + event = dict(status=status, message=expected_message) + + with pytest.raises(AirflowException): + self.operator.execute_complete(context=mock.MagicMock(), event=event) + + mock_log.exception.assert_called_once_with("Trigger ended with one of the failed statuses.") + + +class TestGKECreateClusterOperator: + def setup_method(self): + self.operator = GKECreateClusterOperator( + task_id=TEST_TASK_ID, + project_id=TEST_PROJECT_ID, + location=TEST_LOCATION, + body=GKE_CLUSTER_CREATE_BODY_DICT, + gcp_conn_id=TEST_CONN_ID, + impersonation_chain=TEST_IMPERSONATION_CHAIN, + ) + + def test_template_fields(self): + expected_template_fields = {"body", "api_version", "deferrable", "poll_interval"} | set( + GKEBaseOperator.template_fields + ) + assert set(GKECreateClusterOperator.template_fields) == expected_template_fields + + @pytest.mark.parametrize("body", [GKE_CLUSTER_CREATE_BODY_DICT, GKE_CLUSTER_CREATE_BODY_OBJECT]) + def test_body(self, body): + op = GKECreateClusterOperator( + task_id=TEST_TASK_ID, + project_id=TEST_PROJECT_ID, + location=TEST_LOCATION, + body=body, + gcp_conn_id=TEST_CONN_ID, + impersonation_chain=TEST_IMPERSONATION_CHAIN, + ) + assert op.cluster_name == GKE_CLUSTER_NAME @pytest.mark.parametrize( "body", @@ -213,230 +387,264 @@ def test_create_execute(self, mock_hook, body): )(), ], ) - @mock.patch(GKE_HOOK_PATH) - def test_create_execute_error_body(self, mock_hook, body): - with pytest.raises(AirflowException): - GKECreateClusterOperator( - project_id=TEST_GCP_PROJECT_ID, location=PROJECT_LOCATION, body=body, task_id=PROJECT_TASK_ID - ) + def test_body_error(self, body): + deprecated_fields = {"initial_node_count", "node_config", "zone", "instance_group_urls"} + used_deprecated_fields = {} + if body: + if isinstance(body, dict): + used_deprecated_fields = set(body.keys()).intersection(deprecated_fields) + else: + used_deprecated_fields = {getattr(body, field, None) for field in deprecated_fields} + used_deprecated_fields = {field for field in used_deprecated_fields if field} - @mock.patch(GKE_HOOK_PATH) - def test_create_execute_error_project_id(self, mock_hook): - with pytest.raises(AirflowException): - GKECreateClusterOperator(location=PROJECT_LOCATION, body=PROJECT_BODY, task_id=PROJECT_TASK_ID) + if used_deprecated_fields: + with pytest.raises(AirflowProviderDeprecationWarning): + GKECreateClusterOperator( + project_id=TEST_PROJECT_ID, location=TEST_LOCATION, body=body, task_id=TEST_TASK_ID + ) + else: + with pytest.raises(AirflowException): + GKECreateClusterOperator( + project_id=TEST_PROJECT_ID, location=TEST_LOCATION, body=body, task_id=TEST_TASK_ID + ) - @mock.patch(GKE_HOOK_PATH) - def test_create_execute_error_location(self, mock_hook): - with pytest.raises((TypeError, AirflowException), match="missing keyword argument 'location'"): + @pytest.mark.parametrize( + "deprecated_field_name, deprecated_field_value", + [ + 
("initial_node_count", 1), + ("node_config", mock.MagicMock()), + ("zone", mock.MagicMock()), + ("instance_group_urls", mock.MagicMock()), + ], + ) + def test_alert_deprecated_body_fields(self, deprecated_field_name, deprecated_field_value): + body = deepcopy(GKE_CLUSTER_CREATE_BODY_DICT) + body[deprecated_field_name] = deprecated_field_value + with pytest.raises(AirflowProviderDeprecationWarning): GKECreateClusterOperator( - project_id=TEST_GCP_PROJECT_ID, body=PROJECT_BODY, task_id=PROJECT_TASK_ID + project_id=TEST_PROJECT_ID, location=TEST_LOCATION, body=body, task_id=TEST_TASK_ID ) - @mock.patch("airflow.providers.google.cloud.operators.kubernetes_engine.GKEHook") - @mock.patch("airflow.providers.google.cloud.operators.kubernetes_engine.GKECreateClusterOperator.defer") - def test_create_execute_call_defer_method(self, mock_defer_method, mock_hook): - operator = GKECreateClusterOperator( - project_id=TEST_GCP_PROJECT_ID, - location=PROJECT_LOCATION, - body=PROJECT_BODY_CREATE_DICT_NODE_POOLS, - task_id=PROJECT_TASK_ID, - deferrable=True, - ) + @mock.patch(GKE_OPERATORS_PATH.format("KubernetesEngineClusterLink")) + @mock.patch(GKE_OPERATORS_PATH.format("GKEHook")) + def test_execute(self, mock_cluster_hook, mock_link): + mock_create_cluster = mock_cluster_hook.return_value.create_cluster + mock_operation = mock_create_cluster.return_value + mock_operation.target_link = TEST_TARGET_LINK + mock_context = mock.MagicMock() - operator.execute(mock.MagicMock()) + result = self.operator.execute(context=mock_context) - mock_defer_method.assert_called_once() - - @mock.patch("airflow.providers.google.cloud.operators.kubernetes_engine.GKEHook") - def test_delete_execute(self, mock_hook): - operator = GKEDeleteClusterOperator( - project_id=TEST_GCP_PROJECT_ID, - name=CLUSTER_NAME, - location=PROJECT_LOCATION, - task_id=PROJECT_TASK_ID, + mock_link.persist.assert_called_once_with( + context=mock_context, task_instance=self.operator, cluster=GKE_CLUSTER_CREATE_BODY_DICT ) - - operator.execute(None) - mock_hook.return_value.delete_cluster.assert_called_once_with( - name=CLUSTER_NAME, - project_id=TEST_GCP_PROJECT_ID, + mock_create_cluster.assert_called_once_with( + cluster=GKE_CLUSTER_CREATE_BODY_DICT, + project_id=TEST_PROJECT_ID, wait_to_complete=True, ) + assert result == TEST_TARGET_LINK + + @mock.patch(GKE_OPERATORS_PATH.format("GKECreateClusterOperator.log")) + @mock.patch(GKE_OPERATORS_PATH.format("KubernetesEngineClusterLink")) + @mock.patch(GKE_OPERATORS_PATH.format("GKEHook")) + def test_execute_error(self, mock_cluster_hook, mock_link, mock_log): + mock_create_cluster = mock_cluster_hook.return_value.create_cluster + expected_error_message = "test-message" + mock_create_cluster.side_effect = AlreadyExists(message=expected_error_message) + mock_get_cluster = mock_cluster_hook.return_value.get_cluster + mock_get_cluster.return_value.self_link = TEST_SELF_LINK + mock_context = mock.MagicMock() + + result = self.operator.execute(context=mock_context) + + mock_link.persist.assert_called_once_with( + context=mock_context, task_instance=self.operator, cluster=GKE_CLUSTER_CREATE_BODY_DICT + ) + mock_create_cluster.assert_called_once_with( + cluster=GKE_CLUSTER_CREATE_BODY_DICT, + project_id=TEST_PROJECT_ID, + wait_to_complete=True, + ) + mock_get_cluster.assert_called_once_with( + name=GKE_CLUSTER_NAME, + project_id=TEST_PROJECT_ID, + ) + mock_log.info.assert_called_once_with("Assuming Success: %s", expected_error_message) + assert result == TEST_SELF_LINK + + 
@mock.patch(GKE_OPERATORS_PATH.format("GKEOperationTrigger")) + @mock.patch(GKE_OPERATORS_PATH.format("KubernetesEngineClusterLink")) + @mock.patch(GKE_OPERATORS_PATH.format("GKECreateClusterOperator.defer")) + @mock.patch(GKE_OPERATORS_PATH.format("GKEHook")) + def test_deferrable(self, mock_cluster_hook, mock_defer, mock_link, mock_trigger): + mock_create_cluster = mock_cluster_hook.return_value.create_cluster + mock_operation = mock_create_cluster.return_value + mock_operation.name = TEST_OPERATION_NAME + mock_trigger_instance = mock_trigger.return_value + mock_context = mock.MagicMock() + self.operator.deferrable = True - @mock.patch(GKE_HOOK_PATH) - def test_delete_execute_error_project_id(self, mock_hook): - with pytest.raises(AirflowException): - GKEDeleteClusterOperator(location=PROJECT_LOCATION, name=CLUSTER_NAME, task_id=PROJECT_TASK_ID) + self.operator.execute(context=mock_context) - @mock.patch(GKE_HOOK_PATH) - def test_delete_execute_error_cluster_name(self, mock_hook): - with pytest.raises((TypeError, AirflowException), match="missing keyword argument 'name'"): - GKEDeleteClusterOperator( - project_id=TEST_GCP_PROJECT_ID, location=PROJECT_LOCATION, task_id=PROJECT_TASK_ID - ) + mock_link.persist.assert_called_once_with( + context=mock_context, task_instance=self.operator, cluster=GKE_CLUSTER_CREATE_BODY_DICT + ) + mock_create_cluster.assert_called_once_with( + cluster=GKE_CLUSTER_CREATE_BODY_DICT, + project_id=TEST_PROJECT_ID, + wait_to_complete=False, + ) + mock_trigger.assert_called_once_with( + operation_name=TEST_OPERATION_NAME, + project_id=TEST_PROJECT_ID, + location=TEST_LOCATION, + gcp_conn_id=TEST_CONN_ID, + impersonation_chain=TEST_IMPERSONATION_CHAIN, + poll_interval=10, + ) + mock_defer.assert_called_once_with( + trigger=mock_trigger_instance, + method_name="execute_complete", + ) - @mock.patch(GKE_HOOK_PATH) - def test_delete_execute_error_location(self, mock_hook): - with pytest.raises((TypeError, AirflowException), match="missing keyword argument 'location'"): - GKEDeleteClusterOperator( - project_id=TEST_GCP_PROJECT_ID, name=CLUSTER_NAME, task_id=PROJECT_TASK_ID - ) + @mock.patch(GKE_OPERATORS_PATH.format("GKECreateClusterOperator.log")) + @mock.patch(GKE_OPERATORS_PATH.format("GKEHook")) + def test_execute_complete(self, cluster_hook, mock_log): + mock_get_operation = cluster_hook.return_value.get_operation + mock_get_operation.return_value.target_link = TEST_TARGET_LINK + expected_status, expected_message = "success", "test-message" + event = dict(status=expected_status, message=expected_message, operation_name=TEST_OPERATION_NAME) - @mock.patch("airflow.providers.google.cloud.operators.kubernetes_engine.GKEHook") - @mock.patch("airflow.providers.google.cloud.operators.kubernetes_engine.GKEDeleteClusterOperator.defer") - def test_delete_execute_call_defer_method(self, mock_defer_method, mock_hook): - operator = GKEDeleteClusterOperator( - project_id=TEST_GCP_PROJECT_ID, - name=CLUSTER_NAME, - location=PROJECT_LOCATION, - task_id=PROJECT_TASK_ID, - deferrable=True, - ) + result = self.operator.execute_complete(context=mock.MagicMock(), event=event) + + mock_log.info.assert_called_once_with(expected_message) + mock_get_operation.assert_called_once_with(operation_name=TEST_OPERATION_NAME) + assert result == TEST_TARGET_LINK - operator.execute(None) + @pytest.mark.parametrize("status", ["failed", "error"]) + @mock.patch(GKE_OPERATORS_PATH.format("GKECreateClusterOperator.log")) + def test_execute_complete_error(self, mock_log, status): + expected_message = 
"test-message" + event = dict(status=status, message=expected_message) + + with pytest.raises(AirflowException): + self.operator.execute_complete(context=mock.MagicMock(), event=event) - mock_defer_method.assert_called_once() + mock_log.exception.assert_called_once_with("Trigger ended with one of the failed statuses.") -class TestGKEPodOperator: +class TestGKEStartKueueInsideClusterOperator: def setup_method(self): - self.gke_op = GKEStartPodOperator( - project_id=TEST_GCP_PROJECT_ID, - location=PROJECT_LOCATION, - cluster_name=CLUSTER_NAME, - task_id=PROJECT_TASK_ID, - name=TASK_NAME, - namespace=NAMESPACE, - image=IMAGE, - on_finish_action=OnFinishAction.KEEP_POD, - ) - self.gke_op.pod = mock.MagicMock( - name=TASK_NAME, - namespace=NAMESPACE, + self.operator = GKEStartKueueInsideClusterOperator( + project_id=TEST_PROJECT_ID, + location=TEST_LOCATION, + cluster_name=GKE_CLUSTER_NAME, + task_id=TEST_TASK_ID, + kueue_version=K8S_KUEUE_VERSION, + impersonation_chain=TEST_IMPERSONATION_CHAIN, + use_internal_ip=False, ) def test_template_fields(self): - assert set(KubernetesPodOperator.template_fields).issubset(GKEStartPodOperator.template_fields) - - @mock.patch.dict(os.environ, {}) - @mock.patch(KUB_OPERATOR_EXEC) - @mock.patch(TEMP_FILE) - @mock.patch(f"{GKE_OP_PATH}.fetch_cluster_info") - def test_execute(self, fetch_cluster_info_mock, file_mock, exec_mock): - self.gke_op.execute(context=mock.MagicMock()) - fetch_cluster_info_mock.assert_called_once() - - def test_config_file_throws_error(self): - with pytest.raises(AirflowException): - GKEStartPodOperator( - project_id=TEST_GCP_PROJECT_ID, - location=PROJECT_LOCATION, - cluster_name=CLUSTER_NAME, - task_id=PROJECT_TASK_ID, - name=TASK_NAME, - namespace=NAMESPACE, - image=IMAGE, - config_file="/path/to/alternative/kubeconfig", - on_finish_action=OnFinishAction.KEEP_POD, - ) - - @mock.patch.dict(os.environ, {}) - @mock.patch( - "airflow.hooks.base.BaseHook.get_connection", - return_value=[Connection(extra=json.dumps({"keyfile_dict": '{"private_key": "r4nd0m_k3y"}'}))], - ) - @mock.patch(KUB_OPERATOR_EXEC) - @mock.patch(TEMP_FILE) - @mock.patch(f"{GKE_OP_PATH}.fetch_cluster_info") - def test_execute_with_impersonation_service_account( - self, fetch_cluster_info_mock, file_mock, exec_mock, get_con_mock - ): - self.gke_op.impersonation_chain = "test_account@example.com" - self.gke_op.execute(context=mock.MagicMock()) - fetch_cluster_info_mock.assert_called_once() - - @mock.patch.dict(os.environ, {}) - @mock.patch( - "airflow.hooks.base.BaseHook.get_connection", - return_value=[Connection(extra=json.dumps({"keyfile_dict": '{"private_key": "r4nd0m_k3y"}'}))], - ) - @mock.patch(KUB_OPERATOR_EXEC) - @mock.patch(TEMP_FILE) - @mock.patch(f"{GKE_OP_PATH}.fetch_cluster_info") - def test_execute_with_impersonation_service_chain_one_element( - self, fetch_cluster_info_mock, file_mock, exec_mock, get_con_mock - ): - self.gke_op.impersonation_chain = ["test_account@example.com"] - self.gke_op.execute(context=mock.MagicMock()) - - fetch_cluster_info_mock.assert_called_once() - - @pytest.mark.db_test - @pytest.mark.parametrize("use_internal_ip", [True, False]) - @mock.patch(f"{GKE_HOOK_PATH}.get_cluster") - def test_cluster_info(self, get_cluster_mock, use_internal_ip): - get_cluster_mock.return_value = mock.MagicMock( - **{ - "endpoint": "test-host", - "private_cluster_config.private_endpoint": "test-private-host", - "master_auth.cluster_ca_certificate": SSL_CA_CERT, - } + expected_template_fields = set(GKEBaseOperator.template_fields) | set( + 
KubernetesInstallKueueOperator.template_fields ) - gke_op = GKEStartPodOperator( - project_id=TEST_GCP_PROJECT_ID, - location=PROJECT_LOCATION, - cluster_name=CLUSTER_NAME, - task_id=PROJECT_TASK_ID, - name=TASK_NAME, - namespace=NAMESPACE, - image=IMAGE, - use_internal_ip=use_internal_ip, - on_finish_action=OnFinishAction.KEEP_POD, + assert set(GKEStartKueueInsideClusterOperator.template_fields) == expected_template_fields + + def test_enable_tcp_keepalive(self): + assert self.operator.enable_tcp_keepalive + + @mock.patch(GKE_OPERATORS_PATH.format("super")) + @mock.patch(GKE_OPERATORS_PATH.format("KubernetesEngineClusterLink")) + @mock.patch(GKE_OPERATORS_PATH.format("GKEHook")) + def test_execute(self, mock_hook, mock_link, mock_super): + mock_get_cluster = mock_hook.return_value.get_cluster + mock_cluster = mock_get_cluster.return_value + mock_check_cluster_autoscaling_ability = mock_hook.return_value.check_cluster_autoscaling_ability + mock_check_cluster_autoscaling_ability.return_value = True + mock_context = mock.MagicMock() + + self.operator.execute(context=mock_context) + + mock_get_cluster.assert_called_once_with( + name=GKE_CLUSTER_NAME, + project_id=TEST_PROJECT_ID, ) - cluster_url, ssl_ca_cert = gke_op.fetch_cluster_info() - - assert cluster_url == CLUSTER_PRIVATE_URL if use_internal_ip else CLUSTER_URL - assert ssl_ca_cert == SSL_CA_CERT - - @pytest.mark.db_test - def test_default_gcp_conn_id(self): - gke_op = GKEStartPodOperator( - project_id=TEST_GCP_PROJECT_ID, - location=PROJECT_LOCATION, - cluster_name=CLUSTER_NAME, - task_id=PROJECT_TASK_ID, - name=TASK_NAME, - namespace=NAMESPACE, - image=IMAGE, - on_finish_action=OnFinishAction.KEEP_POD, + mock_link.persist.assert_called_once_with( + context=mock_context, + task_instance=self.operator, + cluster=mock_cluster, + ) + mock_check_cluster_autoscaling_ability.assert_called_once_with(cluster=mock_cluster) + mock_super.assert_called_once() + mock_super.return_value.execute.assert_called_once_with(mock_context) + + @mock.patch(GKE_OPERATORS_PATH.format("GKEStartKueueInsideClusterOperator.log")) + @mock.patch(GKE_OPERATORS_PATH.format("super")) + @mock.patch(GKE_OPERATORS_PATH.format("KubernetesEngineClusterLink")) + @mock.patch(GKE_OPERATORS_PATH.format("GKEHook")) + def test_execute_not_scalable(self, mock_hook, mock_link, mock_super, mock_log): + mock_get_cluster = mock_hook.return_value.get_cluster + mock_cluster = mock_get_cluster.return_value + mock_check_cluster_autoscaling_ability = mock_hook.return_value.check_cluster_autoscaling_ability + mock_check_cluster_autoscaling_ability.return_value = False + mock_context = mock.MagicMock() + + self.operator.execute(context=mock_context) + + mock_get_cluster.assert_called_once_with( + name=GKE_CLUSTER_NAME, + project_id=TEST_PROJECT_ID, + ) + mock_link.persist.assert_called_once_with( + context=mock_context, + task_instance=self.operator, + cluster=mock_cluster, + ) + mock_check_cluster_autoscaling_ability.assert_called_once_with(cluster=mock_cluster) + mock_super.assert_not_called() + mock_log.info.assert_called_once_with( + "Cluster doesn't have ability to autoscale, will not install Kueue inside. 
Aborting" ) - gke_op._cluster_url = CLUSTER_URL - gke_op._ssl_ca_cert = SSL_CA_CERT - hook = gke_op.hook - assert hook.gcp_conn_id == "google_cloud_default" - @mock.patch( - "airflow.providers.google.common.hooks.base_google.GoogleBaseHook.get_connection", - return_value=Connection(conn_id="test_conn"), - ) - def test_gcp_conn_id(self, get_con_mock): - gke_op = GKEStartPodOperator( - project_id=TEST_GCP_PROJECT_ID, - location=PROJECT_LOCATION, - cluster_name=CLUSTER_NAME, - task_id=PROJECT_TASK_ID, - name=TASK_NAME, - namespace=NAMESPACE, - image=IMAGE, - gcp_conn_id="test_conn", +class TestGKEStartPodOperator: + def setup_method(self): + self.operator = GKEStartPodOperator( + project_id=TEST_PROJECT_ID, + location=TEST_LOCATION, + cluster_name=GKE_CLUSTER_NAME, + task_id=TEST_TASK_ID, + name=K8S_POD_NAME, + namespace=K8S_NAMESPACE, + image=TEST_IMAGE, on_finish_action=OnFinishAction.KEEP_POD, + gcp_conn_id=TEST_CONN_ID, + impersonation_chain=TEST_IMPERSONATION_CHAIN, ) - gke_op._cluster_url = CLUSTER_URL - gke_op._ssl_ca_cert = SSL_CA_CERT - hook = gke_op.hook - assert hook.gcp_conn_id == "test_conn" + def test_template_fields(self): + expected_template_fields = ( + {"on_finish_action", "deferrable"} + | (set(KubernetesPodOperator.template_fields) - {"is_delete_operator_pod", "regional"}) + | set(GKEBaseOperator.template_fields) + ) + assert set(GKEStartPodOperator.template_fields) == expected_template_fields + + def test_config_file_error(self): + with pytest.raises(AirflowException): + GKEStartPodOperator( + project_id=TEST_PROJECT_ID, + location=TEST_LOCATION, + cluster_name=GKE_CLUSTER_NAME, + task_id=TEST_TASK_ID, + name=K8S_POD_NAME, + namespace=K8S_NAMESPACE, + image=TEST_IMAGE, + config_file="/path/to/alternative/kubeconfig", + on_finish_action=OnFinishAction.KEEP_POD, + ) @pytest.mark.parametrize( "compatible_kpo, kwargs, expected_attributes", @@ -483,373 +691,137 @@ def test_on_finish_action_handler( expected_attributes, ): kpo_init_args_mock = mock.MagicMock(**{"parameters": ["on_finish_action"] if compatible_kpo else []}) - with mock.patch("inspect.signature", return_value=kpo_init_args_mock): if "is_delete_operator_pod" in kwargs: - with pytest.warns( - AirflowProviderDeprecationWarning, - match="`is_delete_operator_pod` parameter is deprecated, please use `on_finish_action`", - ): - op = GKEStartPodOperator( - project_id=TEST_GCP_PROJECT_ID, - location=PROJECT_LOCATION, - cluster_name=CLUSTER_NAME, - task_id=PROJECT_TASK_ID, - name=TASK_NAME, - namespace=NAMESPACE, - image=IMAGE, + with pytest.raises(AirflowProviderDeprecationWarning): + GKEStartPodOperator( + project_id=TEST_PROJECT_ID, + location=TEST_LOCATION, + cluster_name=GKE_CLUSTER_NAME, + task_id=TEST_TASK_ID, + name=K8S_POD_NAME, + namespace=K8S_NAMESPACE, + image=TEST_IMAGE, **kwargs, ) elif "on_finish_action" not in kwargs: - with pytest.warns( - AirflowProviderDeprecationWarning, - match="You have not set parameter `on_finish_action` in class GKEStartPodOperator. Currently the default for this parameter is `keep_pod` but in a future release the default will be changed to `delete_pod`. 
To ensure pods are not deleted in the future you will need to set `on_finish_action=keep_pod` explicitly.", - ): - op = GKEStartPodOperator( - project_id=TEST_GCP_PROJECT_ID, - location=PROJECT_LOCATION, - cluster_name=CLUSTER_NAME, - task_id=PROJECT_TASK_ID, - name=TASK_NAME, - namespace=NAMESPACE, - image=IMAGE, + with pytest.raises(AirflowProviderDeprecationWarning): + GKEStartPodOperator( + project_id=TEST_PROJECT_ID, + location=TEST_LOCATION, + cluster_name=GKE_CLUSTER_NAME, + task_id=TEST_TASK_ID, + name=K8S_POD_NAME, + namespace=K8S_NAMESPACE, + image=TEST_IMAGE, **kwargs, ) else: op = GKEStartPodOperator( - project_id=TEST_GCP_PROJECT_ID, - location=PROJECT_LOCATION, - cluster_name=CLUSTER_NAME, - task_id=PROJECT_TASK_ID, - name=TASK_NAME, - namespace=NAMESPACE, - image=IMAGE, + project_id=TEST_PROJECT_ID, + location=TEST_LOCATION, + cluster_name=GKE_CLUSTER_NAME, + task_id=TEST_TASK_ID, + name=K8S_POD_NAME, + namespace=K8S_NAMESPACE, + image=TEST_IMAGE, **kwargs, ) - for expected_attr in expected_attributes: - assert op.__getattribute__(expected_attr) == expected_attributes[expected_attr] - - -class TestGKEStartKueueInsideClusterOperator: - @pytest.fixture(autouse=True) - def setup_test(self): - self.gke_op = GKEStartKueueInsideClusterOperator( - project_id=TEST_GCP_PROJECT_ID, - location=PROJECT_LOCATION, - cluster_name=CLUSTER_NAME, - task_id=PROJECT_TASK_ID, - kueue_version=KUEUE_VERSION, - impersonation_chain=IMPERSONATION_CHAIN, - use_internal_ip=USE_INTERNAL_API, - ) - self.gke_op._cluster_url = CLUSTER_URL - self.gke_op._ssl_ca_cert = SSL_CA_CERT - - @pytest.mark.flaky(reruns=5) - @pytest.mark.db_test - @mock.patch.dict(os.environ, {}) - @mock.patch(TEMP_FILE) - @mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info") - @mock.patch(GKE_HOOK_PATH) - @mock.patch(f"{GKE_KUBERNETES_HOOK}.check_kueue_deployment_running") - @mock.patch(GKE_KUBERNETES_HOOK) - def test_execute(self, mock_pod_hook, mock_deployment, mock_hook, fetch_cluster_info_mock, file_mock): - mock_pod_hook.return_value.apply_from_yaml_file.side_effect = mock.MagicMock() - fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT) - mock_hook.return_value.get_cluster.return_value = PROJECT_BODY_CREATE_CLUSTER - self.gke_op.execute(context=mock.MagicMock()) - - fetch_cluster_info_mock.assert_called_once() - - @pytest.mark.flaky(reruns=5) - @mock.patch.dict(os.environ, {}) - @mock.patch(TEMP_FILE) - @mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info") - @mock.patch(GKE_KUBERNETES_HOOK) - @mock.patch(GKE_HOOK_PATH) - @mock.patch(GKE_KUBERNETES_HOOK) - def test_execute_autoscaled_cluster( - self, mock_pod_hook, mock_hook, mock_depl_hook, fetch_cluster_info_mock, file_mock, caplog - ): - fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT) - mock_hook.return_value.get_cluster.return_value = mock.MagicMock() - mock_pod_hook.return_value.apply_from_yaml_file.side_effect = mock.MagicMock() - mock_hook.return_value.check_cluster_autoscaling_ability.return_value = True - mock_depl_hook.return_value.get_deployment_status.return_value = READY_DEPLOYMENT - self.gke_op.execute(context=mock.MagicMock()) - - assert "Kueue installed successfully!" 
in caplog.text - - @pytest.mark.flaky(reruns=5) - @mock.patch.dict(os.environ, {}) - @mock.patch(TEMP_FILE) - @mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info") - @mock.patch(GKE_HOOK_PATH) - @mock.patch(GKE_KUBERNETES_HOOK) - def test_execute_autoscaled_cluster_check_error( - self, mock_pod_hook, mock_hook, fetch_cluster_info_mock, file_mock, caplog + for expected_attr in expected_attributes: + assert op.__getattribute__(expected_attr) == expected_attributes[expected_attr] + + @mock.patch(GKE_OPERATORS_PATH.format("GKEStartPodOperator.defer")) + @mock.patch(GKE_OPERATORS_PATH.format("GKEClusterAuthDetails.fetch_cluster_info")) + @mock.patch(GKE_OPERATORS_PATH.format("GKEHook")) + @mock.patch(GKE_OPERATORS_PATH.format("GKEStartPodTrigger")) + @mock.patch(GKE_OPERATORS_PATH.format("utcnow")) + def test_invoke_defer_method( + self, mock_utcnow, mock_trigger, mock_cluster_hook, mock_fetch_cluster_info, mock_defer ): - fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT) - mock_hook.return_value.get_cluster.return_value = mock.MagicMock() - mock_hook.return_value.check_cluster_autoscaling_ability.return_value = True - mock_pod_hook.return_value.apply_from_yaml_file.side_effect = FailToCreateError("error") - self.gke_op.execute(context=mock.MagicMock()) - - assert "Kueue is already enabled for the cluster" in caplog.text - - @mock.patch.dict(os.environ, {}) - @mock.patch(TEMP_FILE) - @mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info") - @mock.patch(GKE_HOOK_PATH) - @mock.patch(GKE_KUBERNETES_HOOK) - def test_execute_non_autoscaled_cluster_check_error( - self, mock_pod_hook, mock_hook, fetch_cluster_info_mock, file_mock, caplog - ): - fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT) - mock_hook.return_value.get_cluster.return_value = mock.MagicMock() - mock_hook.return_value.check_cluster_autoscaling_ability.return_value = False - self.gke_op.execute(context=mock.MagicMock()) + mock_trigger_start_time = mock_utcnow.return_value - assert ( - "Cluster doesn't have ability to autoscale, will not install Kueue inside. 
Aborting" - in caplog.text - ) - mock_pod_hook.assert_not_called() - - @mock.patch.dict(os.environ, {}) - @mock.patch( - "airflow.hooks.base.BaseHook.get_connection", - return_value=[Connection(extra=json.dumps({"keyfile_dict": '{"private_key": "r4nd0m_k3y"}'}))], - ) - @mock.patch(TEMP_FILE) - @mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info") - @mock.patch(GKE_HOOK_PATH) - def test_execute_with_impersonation_service_account( - self, mock_hook, fetch_cluster_info_mock, file_mock, get_con_mock - ): - fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT) - mock_hook.return_value.get_cluster.return_value = PROJECT_BODY_CREATE_CLUSTER - mock_hook.return_value.check_cluster_autoscaling_ability.return_value = False - self.gke_op.impersonation_chain = "test_account@example.com" - self.gke_op.execute(context=mock.MagicMock()) - - fetch_cluster_info_mock.assert_called_once() - - @mock.patch.dict(os.environ, {}) - @mock.patch( - "airflow.hooks.base.BaseHook.get_connection", - return_value=[Connection(extra=json.dumps({"keyfile_dict": '{"private_key": "r4nd0m_k3y"}'}))], - ) - @mock.patch(TEMP_FILE) - @mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info") - @mock.patch(GKE_HOOK_PATH) - def test_execute_with_impersonation_service_chain_one_element( - self, mock_hook, fetch_cluster_info_mock, file_mock, get_con_mock - ): - fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT) - mock_hook.return_value.get_cluster.return_value = PROJECT_BODY_CREATE_CLUSTER - mock_hook.return_value.check_cluster_autoscaling_ability.return_value = False - self.gke_op.impersonation_chain = ["test_account@example.com"] - self.gke_op.execute(context=mock.MagicMock()) - - fetch_cluster_info_mock.assert_called_once() - - @pytest.mark.db_test - def test_default_gcp_conn_id(self): - gke_op = GKEStartKueueInsideClusterOperator( - project_id=TEST_GCP_PROJECT_ID, - location=PROJECT_LOCATION, - cluster_name=CLUSTER_NAME, - task_id=PROJECT_TASK_ID, - kueue_version=KUEUE_VERSION, - impersonation_chain=IMPERSONATION_CHAIN, - use_internal_ip=USE_INTERNAL_API, - ) - gke_op._cluster_url = CLUSTER_URL - gke_op._ssl_ca_cert = SSL_CA_CERT - hook = gke_op.cluster_hook + mock_metadata = mock.MagicMock() + mock_metadata.name = K8S_POD_NAME + mock_metadata.namespace = K8S_NAMESPACE + self.operator.pod = mock.MagicMock(metadata=mock_metadata) + mock_fetch_cluster_info.return_value = GKE_CLUSTER_URL, GKE_SSL_CA_CERT + mock_get_logs = mock.MagicMock() + self.operator.get_logs = mock_get_logs + mock_last_log_time = mock.MagicMock() - assert hook.gcp_conn_id == "google_cloud_default" + self.operator.invoke_defer_method(last_log_time=mock_last_log_time) - @mock.patch.dict(os.environ, {}) - @mock.patch( - "airflow.hooks.base.BaseHook.get_connection", - return_value=Connection(extra=json.dumps({"keyfile_dict": '{"private_key": "r4nd0m_k3y"}'})), - ) - def test_gcp_conn_id(self, mock_get_credentials): - gke_op = GKEStartKueueInsideClusterOperator( - project_id=TEST_GCP_PROJECT_ID, - location=PROJECT_LOCATION, - cluster_name=CLUSTER_NAME, - task_id=PROJECT_TASK_ID, - kueue_version=KUEUE_VERSION, - impersonation_chain=IMPERSONATION_CHAIN, - use_internal_ip=USE_INTERNAL_API, - gcp_conn_id="test_conn", + mock_trigger.assert_called_once_with( + pod_name=K8S_POD_NAME, + pod_namespace=K8S_NAMESPACE, + trigger_start_time=mock_trigger_start_time, + cluster_url=GKE_CLUSTER_URL, + ssl_ca_cert=GKE_SSL_CA_CERT, + get_logs=mock_get_logs, + startup_timeout=120, + cluster_context=None, + poll_interval=2, + 
in_cluster=None, + base_container_name="base", + on_finish_action=OnFinishAction.KEEP_POD, + gcp_conn_id=TEST_CONN_ID, + impersonation_chain=TEST_IMPERSONATION_CHAIN, + logging_interval=None, + last_log_time=mock_last_log_time, + ) + mock_defer.assert_called_once_with( + trigger=mock_trigger.return_value, + method_name="trigger_reentry", ) - gke_op._cluster_url = CLUSTER_URL - gke_op._ssl_ca_cert = SSL_CA_CERT - hook = gke_op.cluster_hook - - assert hook.gcp_conn_id == "test_conn" - - @mock.patch(f"{GKE_HOOK_MODULE_PATH}.requests") - @mock.patch(f"{GKE_HOOK_MODULE_PATH}.yaml") - def test_get_yaml_content_from_file(self, mock_yaml, mock_requests): - yaml_content_expected = [mock.MagicMock(), mock.MagicMock()] - mock_yaml.safe_load_all.return_value = yaml_content_expected - response_text_expected = "response test expected" - mock_requests.get.return_value = mock.MagicMock(status_code=200, text=response_text_expected) - - yaml_content_actual = GKEStartKueueInsideClusterOperator._get_yaml_content_from_file(KUEUE_YAML_URL) - - assert yaml_content_actual == yaml_content_expected - mock_requests.get.assert_called_once_with(KUEUE_YAML_URL, allow_redirects=True) - mock_yaml.safe_load_all.assert_called_once_with(response_text_expected) - - @mock.patch(f"{GKE_HOOK_MODULE_PATH}.requests") - def test_get_yaml_content_from_file_exception(self, mock_requests): - mock_requests.get.return_value = mock.MagicMock(status_code=400) - - with pytest.raises(AirflowException): - GKEStartKueueInsideClusterOperator._get_yaml_content_from_file(KUEUE_YAML_URL) -class TestGKEPodOperatorAsync: +class TestGKEStartJobOperator: def setup_method(self): - self.gke_op = GKEStartPodOperator( - project_id=TEST_GCP_PROJECT_ID, - location=PROJECT_LOCATION, - cluster_name=CLUSTER_NAME, - task_id=PROJECT_TASK_ID, - name=TASK_NAME, - namespace=NAMESPACE, - image=IMAGE, - deferrable=True, - on_finish_action="delete_pod", + self.operator = GKEStartJobOperator( + project_id=TEST_PROJECT_ID, + location=TEST_LOCATION, + cluster_name=GKE_CLUSTER_NAME, + task_id=TEST_TASK_ID, + name=K8S_JOB_NAME, + namespace=K8S_NAMESPACE, + image=TEST_IMAGE, + gcp_conn_id=TEST_CONN_ID, + impersonation_chain=TEST_IMPERSONATION_CHAIN, ) - self.gke_op.pod = mock.MagicMock( - name=TASK_NAME, - namespace=NAMESPACE, + + def test_template_fields(self): + expected_template_fields = ( + {"deferrable", "poll_interval"} + | set(GKEBaseOperator.template_fields) + | set(KubernetesJobOperator.template_fields) ) - self.gke_op._cluster_url = CLUSTER_URL - self.gke_op._ssl_ca_cert = SSL_CA_CERT - - @mock.patch.dict(os.environ, {}) - @mock.patch(KUB_OP_PATH.format("build_pod_request_obj")) - @mock.patch(KUB_OP_PATH.format("get_or_create_pod")) - @mock.patch( - "airflow.hooks.base.BaseHook.get_connection", - return_value=[Connection(extra=json.dumps({"keyfile_dict": '{"private_key": "r4nd0m_k3y"}'}))], - ) - @mock.patch(f"{GKE_OP_PATH}.fetch_cluster_info") - def test_async_create_pod_should_execute_successfully( - self, fetch_cluster_info_mock, get_con_mock, mocked_pod, mocked_pod_obj, mocker - ): - """ - Asserts that a task is deferred and the GKEStartPodTrigger will be fired - when the GKEStartPodOperator is executed in deferrable mode when deferrable=True. 
- """ - mock_file = mock_open(read_data='{"a": "b"}') - mocker.patch("builtins.open", mock_file) - - self.gke_op._cluster_url = CLUSTER_URL - self.gke_op._ssl_ca_cert = SSL_CA_CERT - with pytest.raises(TaskDeferred) as exc: - self.gke_op.execute(context=mock.MagicMock()) - fetch_cluster_info_mock.assert_called_once() - assert isinstance(exc.value.trigger, GKEStartPodTrigger) - - @pytest.mark.parametrize("status", ["error", "failed", "timeout"]) - @mock.patch("airflow.providers.cncf.kubernetes.hooks.kubernetes.KubernetesHook.get_pod") - @mock.patch(KUB_OP_PATH.format("_clean")) - @mock.patch("airflow.providers.google.cloud.operators.kubernetes_engine.GKEStartPodOperator.hook") - @mock.patch(KUB_OP_PATH.format("_write_logs")) - def test_execute_complete_failure(self, mock_write_logs, mock_gke_hook, mock_clean, mock_get_pod, status): - self.gke_op._cluster_url = CLUSTER_URL - self.gke_op._ssl_ca_cert = SSL_CA_CERT + assert set(GKEStartJobOperator.template_fields) == expected_template_fields + + def test_config_file_throws_error(self): with pytest.raises(AirflowException): - self.gke_op.execute_complete( - context=mock.MagicMock(), - event={"name": "test", "status": status, "namespace": "default", "message": ""}, - cluster_url=self.gke_op._cluster_url, - ssl_ca_cert=self.gke_op._ssl_ca_cert, + GKEStartJobOperator( + project_id=TEST_PROJECT_ID, + location=TEST_LOCATION, + cluster_name=GKE_CLUSTER_NAME, + task_id=TEST_TASK_ID, + name=K8S_JOB_NAME, + namespace=K8S_NAMESPACE, + image=TEST_IMAGE, + config_file="/path/to/alternative/kubeconfig", ) - mock_write_logs.assert_called_once() - - @mock.patch("airflow.providers.google.cloud.operators.kubernetes_engine.GKEStartPodOperator.hook") - @mock.patch("airflow.providers.cncf.kubernetes.hooks.kubernetes.KubernetesHook.get_pod") - @mock.patch(KUB_OP_PATH.format("_clean")) - @mock.patch(KUB_OP_PATH.format("_write_logs")) - def test_execute_complete_success(self, mock_write_logs, mock_clean, mock_get_pod, mock_gke_hook): - self.gke_op._cluster_url = CLUSTER_URL - self.gke_op._ssl_ca_cert = SSL_CA_CERT - self.gke_op.execute_complete( - context=mock.MagicMock(), - event={"name": "test", "status": "success", "namespace": "default"}, - cluster_url=self.gke_op._cluster_url, - ssl_ca_cert=self.gke_op._ssl_ca_cert, - ) - mock_write_logs.assert_called_once() - @mock.patch(KUB_OP_PATH.format("pod_manager")) - @mock.patch( - "airflow.providers.google.cloud.operators.kubernetes_engine.GKEStartPodOperator.invoke_defer_method" - ) - @mock.patch("airflow.providers.cncf.kubernetes.hooks.kubernetes.KubernetesHook.get_pod") - @mock.patch(KUB_OP_PATH.format("_clean")) - @mock.patch("airflow.providers.google.cloud.operators.kubernetes_engine.GKEStartPodOperator.hook") - def test_execute_complete_running( - self, mock_gke_hook, mock_clean, mock_get_pod, mock_invoke_defer_method, mock_pod_manager - ): - self.gke_op._cluster_url = CLUSTER_URL - self.gke_op._ssl_ca_cert = SSL_CA_CERT - self.gke_op.execute_complete( - context=mock.MagicMock(), - event={"name": "test", "status": "running", "namespace": "default"}, - cluster_url=self.gke_op._cluster_url, - ssl_ca_cert=self.gke_op._ssl_ca_cert, - ) - mock_pod_manager.fetch_container_logs.assert_called_once() - mock_invoke_defer_method.assert_called_once() + @mock.patch(GKE_OPERATORS_PATH.format("super")) + def test_execute(self, mock_super): + mock_context = mock.MagicMock() + self.operator.execute(context=mock_context) -class TestGKEStartJobOperator: - def setup_method(self): - self.gke_op = GKEStartJobOperator( - 
project_id=TEST_GCP_PROJECT_ID, - location=PROJECT_LOCATION, - cluster_name=CLUSTER_NAME, - task_id=PROJECT_TASK_ID, - name=TASK_NAME, - namespace=NAMESPACE, - image=IMAGE, - ) - self.gke_op.job = mock.MagicMock( - name=TASK_NAME, - namespace=NAMESPACE, - ) + mock_super.assert_called_once() + mock_super.return_value.execute.assert_called_once_with(mock_context) - def test_template_fields(self): - assert set(KubernetesJobOperator.template_fields).issubset(GKEStartJobOperator.template_fields) - - @mock.patch.dict(os.environ, {}) - @mock.patch(KUB_JOB_OPERATOR_EXEC) - @mock.patch(TEMP_FILE) - @mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info") - @mock.patch(GKE_HOOK_PATH) - def test_execute(self, mock_hook, fetch_cluster_info_mock, file_mock, exec_mock): - fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT) - self.gke_op.execute(context=mock.MagicMock()) - fetch_cluster_info_mock.assert_called_once() - - @mock.patch(KUB_JOB_OPERATOR_EXEC) - @mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info") - @mock.patch(GKE_HOOK_PATH) - @mock.patch(f"{GKE_HOOK_MODULE_PATH}.ProvidersManager") - def test_execute_in_deferrable_mode( - self, mock_providers_manager, mock_hook, fetch_cluster_info_mock, exec_mock - ): + @mock.patch(GKE_OPERATORS_PATH.format("super")) + @mock.patch(GKE_OPERATORS_PATH.format("ProvidersManager")) + def test_deferrable(self, mock_providers_manager, mock_super): kubernetes_package_name = "apache-airflow-providers-cncf-kubernetes" mock_providers_manager.return_value.providers = { kubernetes_package_name: mock.MagicMock( @@ -859,14 +831,18 @@ def test_execute_in_deferrable_mode( version="8.0.2", ) } - self.gke_op.deferrable = True + mock_context = mock.MagicMock() + self.operator.deferrable = True - fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT) - self.gke_op.execute(context=mock.MagicMock()) - fetch_cluster_info_mock.assert_called_once() + self.operator.execute(context=mock_context) - @mock.patch(f"{GKE_HOOK_MODULE_PATH}.ProvidersManager") - def test_execute_in_deferrable_mode_exception(self, mock_providers_manager): + mock_providers_manager.assert_called_once() + mock_super.assert_called_once() + mock_super.return_value.execute.assert_called_once_with(mock_context) + + @mock.patch(GKE_OPERATORS_PATH.format("super")) + @mock.patch(GKE_OPERATORS_PATH.format("ProvidersManager")) + def test_deferrable_error(self, mock_providers_manager, mock_super): kubernetes_package_name = "apache-airflow-providers-cncf-kubernetes" mock_providers_manager.return_value.providers = { kubernetes_package_name: mock.MagicMock( @@ -876,811 +852,384 @@ def test_execute_in_deferrable_mode_exception(self, mock_providers_manager): version="8.0.1", ) } - self.gke_op.deferrable = True + self.operator.deferrable = True + with pytest.raises(AirflowException): - self.gke_op.execute({}) + self.operator.execute(context=mock.MagicMock()) - @mock.patch(f"{GKE_HOOK_MODULE_PATH}.GKEJobTrigger") - def test_execute_deferrable(self, mock_trigger): - mock_trigger_instance = mock_trigger.return_value + mock_providers_manager.assert_called_once() + mock_super.assert_not_called() + mock_super.return_value.execute.assert_not_called() - op = GKEStartJobOperator( - project_id=TEST_GCP_PROJECT_ID, - location=PROJECT_LOCATION, - cluster_name=CLUSTER_NAME, - task_id=PROJECT_TASK_ID, - name=TASK_NAME, - namespace=NAMESPACE, - image=IMAGE, - job_poll_interval=JOB_POLL_INTERVAL, - ) - op._ssl_ca_cert = SSL_CA_CERT - op._cluster_url = CLUSTER_URL + 
@mock.patch(GKE_OPERATORS_PATH.format("GKEStartJobOperator.defer")) + @mock.patch(GKE_OPERATORS_PATH.format("GKEClusterAuthDetails.fetch_cluster_info")) + @mock.patch(GKE_OPERATORS_PATH.format("GKEHook")) + @mock.patch(GKE_OPERATORS_PATH.format("GKEJobTrigger")) + def test_execute_deferrable(self, mock_trigger, mock_cluster_hook, mock_fetch_cluster_info, mock_defer): + mock_pod_metadata = mock.MagicMock() + mock_pod_metadata.name = K8S_POD_NAME + mock_pod_metadata.namespace = K8S_NAMESPACE + self.operator.pod = mock.MagicMock(metadata=mock_pod_metadata) - with mock.patch.object(op, "job") as mock_job: - mock_metadata = mock_job.metadata - mock_metadata.name = TASK_NAME - mock_metadata.namespace = NAMESPACE + mock_job_metadata = mock.MagicMock() + mock_job_metadata.name = K8S_JOB_NAME + mock_job_metadata.namespace = K8S_NAMESPACE + self.operator.job = mock.MagicMock(metadata=mock_job_metadata) - mock_pod = mock.MagicMock() - mock_pod.metadata.name = POD_NAME - mock_pod.metadata.namespace = NAMESPACE - op.pod = mock_pod + mock_fetch_cluster_info.return_value = GKE_CLUSTER_URL, GKE_SSL_CA_CERT + mock_get_logs = mock.MagicMock() + self.operator.get_logs = mock_get_logs - with mock.patch.object(op, "defer") as mock_defer: - op.execute_deferrable() + self.operator.execute_deferrable() mock_trigger.assert_called_once_with( - cluster_url=CLUSTER_URL, - ssl_ca_cert=SSL_CA_CERT, - job_name=TASK_NAME, - job_namespace=NAMESPACE, - pod_name=POD_NAME, - pod_namespace=NAMESPACE, - base_container_name=op.BASE_CONTAINER_NAME, - gcp_conn_id="google_cloud_default", - poll_interval=JOB_POLL_INTERVAL, - impersonation_chain=None, - get_logs=True, + cluster_url=GKE_CLUSTER_URL, + ssl_ca_cert=GKE_SSL_CA_CERT, + job_name=K8S_JOB_NAME, + job_namespace=K8S_NAMESPACE, + pod_name=K8S_POD_NAME, + pod_namespace=K8S_NAMESPACE, + base_container_name="base", + gcp_conn_id=TEST_CONN_ID, + poll_interval=10.0, + impersonation_chain=TEST_IMPERSONATION_CHAIN, + get_logs=mock_get_logs, do_xcom_push=False, ) mock_defer.assert_called_once_with( - trigger=mock_trigger_instance, + trigger=mock_trigger.return_value, method_name="execute_complete", - kwargs={"cluster_url": CLUSTER_URL, "ssl_ca_cert": SSL_CA_CERT}, - ) - - def test_config_file_throws_error(self): - with pytest.raises(AirflowException): - GKEStartJobOperator( - project_id=TEST_GCP_PROJECT_ID, - location=PROJECT_LOCATION, - cluster_name=CLUSTER_NAME, - task_id=PROJECT_TASK_ID, - name=TASK_NAME, - namespace=NAMESPACE, - image=IMAGE, - config_file="/path/to/alternative/kubeconfig", - ) - - @mock.patch.dict(os.environ, {}) - @mock.patch( - "airflow.hooks.base.BaseHook.get_connection", - return_value=[Connection(extra=json.dumps({"keyfile_dict": '{"private_key": "r4nd0m_k3y"}'}))], - ) - @mock.patch(KUB_JOB_OPERATOR_EXEC) - @mock.patch(TEMP_FILE) - @mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info") - @mock.patch(GKE_HOOK_PATH) - def test_execute_with_impersonation_service_account( - self, mock_hook, fetch_cluster_info_mock, file_mock, exec_mock, get_con_mock - ): - fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT) - self.gke_op.impersonation_chain = "test_account@example.com" - self.gke_op.execute(context=mock.MagicMock()) - fetch_cluster_info_mock.assert_called_once() - - @mock.patch.dict(os.environ, {}) - @mock.patch( - "airflow.hooks.base.BaseHook.get_connection", - return_value=[Connection(extra=json.dumps({"keyfile_dict": '{"private_key": "r4nd0m_k3y"}'}))], - ) - @mock.patch(KUB_JOB_OPERATOR_EXEC) - @mock.patch(TEMP_FILE) - 
@mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info") - @mock.patch(GKE_HOOK_PATH) - def test_execute_with_impersonation_service_chain_one_element( - self, mock_hook, fetch_cluster_info_mock, file_mock, exec_mock, get_con_mock - ): - fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT) - self.gke_op.impersonation_chain = ["test_account@example.com"] - self.gke_op.execute(context=mock.MagicMock()) - - fetch_cluster_info_mock.assert_called_once() - - @pytest.mark.db_test - def test_default_gcp_conn_id(self): - gke_op = GKEStartJobOperator( - project_id=TEST_GCP_PROJECT_ID, - location=PROJECT_LOCATION, - cluster_name=CLUSTER_NAME, - task_id=PROJECT_TASK_ID, - name=TASK_NAME, - namespace=NAMESPACE, - image=IMAGE, - ) - gke_op._cluster_url = CLUSTER_URL - gke_op._ssl_ca_cert = SSL_CA_CERT - hook = gke_op.hook - - assert hook.gcp_conn_id == "google_cloud_default" - - @mock.patch( - "airflow.providers.google.common.hooks.base_google.GoogleBaseHook.get_connection", - return_value=Connection(conn_id="test_conn"), - ) - def test_gcp_conn_id(self, get_con_mock): - gke_op = GKEStartJobOperator( - project_id=TEST_GCP_PROJECT_ID, - location=PROJECT_LOCATION, - cluster_name=CLUSTER_NAME, - task_id=PROJECT_TASK_ID, - name=TASK_NAME, - namespace=NAMESPACE, - image=IMAGE, - gcp_conn_id="test_conn", ) - gke_op._cluster_url = CLUSTER_URL - gke_op._ssl_ca_cert = SSL_CA_CERT - hook = gke_op.hook - - assert hook.gcp_conn_id == "test_conn" class TestGKEDescribeJobOperator: def setup_method(self): - self.gke_op = GKEDescribeJobOperator( - project_id=TEST_GCP_PROJECT_ID, - location=PROJECT_LOCATION, - cluster_name=CLUSTER_NAME, - task_id=PROJECT_TASK_ID, - job_name=JOB_NAME, - namespace=NAMESPACE, - ) - self.gke_op.job = mock.MagicMock( - name=TASK_NAME, - namespace=NAMESPACE, - ) - - @mock.patch.dict(os.environ, {}) - @mock.patch(TEMP_FILE) - @mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info") - @mock.patch(GKE_HOOK_PATH) - @mock.patch(GKE_KUBERNETES_HOOK) - def test_execute(self, mock_job_hook, mock_hook, fetch_cluster_info_mock, file_mock): - mock_job_hook.return_value.get_job.return_value = mock.MagicMock() - fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT) - self.gke_op.execute(context=mock.MagicMock()) - fetch_cluster_info_mock.assert_called_once() - - @mock.patch.dict(os.environ, {}) - @mock.patch( - "airflow.hooks.base.BaseHook.get_connection", - return_value=[Connection(extra=json.dumps({"keyfile_dict": '{"private_key": "r4nd0m_k3y"}'}))], - ) - @mock.patch(TEMP_FILE) - @mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info") - @mock.patch(GKE_HOOK_PATH) - @mock.patch(GKE_KUBERNETES_HOOK) - def test_execute_with_impersonation_service_account( - self, mock_job_hook, mock_hook, fetch_cluster_info_mock, file_mock, get_con_mock - ): - mock_job_hook.return_value.get_job.return_value = mock.MagicMock() - fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT) - self.gke_op.impersonation_chain = "test_account@example.com" - self.gke_op.execute(context=mock.MagicMock()) - fetch_cluster_info_mock.assert_called_once() - - @mock.patch.dict(os.environ, {}) - @mock.patch( - "airflow.hooks.base.BaseHook.get_connection", - return_value=[Connection(extra=json.dumps({"keyfile_dict": '{"private_key": "r4nd0m_k3y"}'}))], - ) - @mock.patch(TEMP_FILE) - @mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info") - @mock.patch(GKE_HOOK_PATH) - @mock.patch(GKE_KUBERNETES_HOOK) - def test_execute_with_impersonation_service_chain_one_element( - 
self, mock_job_hook, mock_hook, fetch_cluster_info_mock, file_mock, get_con_mock - ): - fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT) - self.gke_op.impersonation_chain = ["test_account@example.com"] - self.gke_op.execute(context=mock.MagicMock()) - - fetch_cluster_info_mock.assert_called_once() - - @pytest.mark.db_test - @mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info") - def test_default_gcp_conn_id(self, fetch_cluster_info_mock): - gke_op = GKEDescribeJobOperator( - project_id=TEST_GCP_PROJECT_ID, - location=PROJECT_LOCATION, - cluster_name=CLUSTER_NAME, - task_id=PROJECT_TASK_ID, - job_name=TASK_NAME, - namespace=NAMESPACE, + self.operator = GKEDescribeJobOperator( + project_id=TEST_PROJECT_ID, + location=TEST_LOCATION, + cluster_name=GKE_CLUSTER_NAME, + task_id=TEST_TASK_ID, + job_name=K8S_JOB_NAME, + namespace=K8S_NAMESPACE, ) - fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT) - hook = gke_op.hook - - assert hook.gcp_conn_id == "google_cloud_default" - @mock.patch( - "airflow.providers.google.common.hooks.base_google.GoogleBaseHook.get_connection", - return_value=Connection(conn_id="test_conn"), - ) - @mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info") - @mock.patch(GKE_HOOK_PATH) - def test_gcp_conn_id(self, mock_hook, fetch_cluster_info_mock, mock_gke_conn): - gke_op = GKEDescribeJobOperator( - project_id=TEST_GCP_PROJECT_ID, - location=PROJECT_LOCATION, - cluster_name=CLUSTER_NAME, - task_id=PROJECT_TASK_ID, - job_name=TASK_NAME, - namespace=NAMESPACE, - gcp_conn_id="test_conn", + @mock.patch(GKE_OPERATORS_PATH.format("KubernetesEngineJobLink")) + @mock.patch(GKE_OPERATORS_PATH.format("GKEDescribeJobOperator.log")) + @mock.patch(GKE_OPERATORS_PATH.format("GKEClusterAuthDetails.fetch_cluster_info")) + @mock.patch(GKE_OPERATORS_PATH.format("GKEKubernetesHook")) + @mock.patch(GKE_OPERATORS_PATH.format("GKEHook")) + def test_execute(self, mock_cluster_hook, mock_hook, mock_fetch_cluster_info, mock_log, mock_link): + mock_fetch_cluster_info.return_value = GKE_CLUSTER_URL, GKE_SSL_CA_CERT + mock_job = mock_hook.return_value.get_job.return_value + mock_context = mock.MagicMock() + + self.operator.execute(context=mock_context) + + mock_hook.return_value.get_job.assert_called_once_with(job_name=K8S_JOB_NAME, namespace=K8S_NAMESPACE) + mock_log.info.assert_called_once_with( + "Retrieved description of Job %s from cluster %s:\n %s", + K8S_JOB_NAME, + GKE_CLUSTER_NAME, + mock_job, ) - fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT) - hook = gke_op.hook - - assert hook.gcp_conn_id == "test_conn" + mock_link.persist.assert_called_once_with(context=mock_context, task_instance=self.operator) -class TestGKECreateCustomResourceOperator: +class TestGKEListJobsOperator: def setup_method(self): - self.gke_op = GKECreateCustomResourceOperator( - project_id=TEST_GCP_PROJECT_ID, - location=PROJECT_LOCATION, - cluster_name=CLUSTER_NAME, - task_id=PROJECT_TASK_ID, - yaml_conf=VALID_RESOURCE_YAML, + self.operator = GKEListJobsOperator( + project_id=TEST_PROJECT_ID, + location=TEST_LOCATION, + cluster_name=GKE_CLUSTER_NAME, + task_id=TEST_TASK_ID, ) def test_template_fields(self): - assert set(KubernetesCreateResourceOperator.template_fields).issubset( - GKECreateCustomResourceOperator.template_fields + expected_template_fields = {"namespace"} | set(GKEBaseOperator.template_fields) + assert set(GKEListJobsOperator.template_fields) == expected_template_fields + + 
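These template_fields equality checks lean on plain set algebra: an operator's template_fields is expected to be exactly the union of its base classes' fields plus its own additions, and comparing as sets keeps the assertion order-insensitive. A minimal self-contained sketch of the pattern, with BaseOp and ListOp as illustrative stand-ins rather than classes from this provider:

    class BaseOp:
        # Stand-in for a base operator exposing common templated fields.
        template_fields = ("project_id", "location", "cluster_name")

    class ListOp(BaseOp):
        # A subclass extends the inherited tuple with its own field.
        template_fields = (*BaseOp.template_fields, "namespace")

    # Mirrors the assertion style used in the tests above: build the expected
    # set from the parent's fields plus the subclass additions, then compare.
    expected = {"namespace"} | set(BaseOp.template_fields)
    assert set(ListOp.template_fields) == expected  # order does not matter

Comparing sets rather than tuples means a reordering of template_fields never fails the test, while a missing or extra field always does.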
@mock.patch(GKE_OPERATORS_PATH.format("KubernetesEngineWorkloadsLink")) + @mock.patch(GKE_OPERATORS_PATH.format("V1JobList.to_dict")) + @mock.patch(GKE_OPERATORS_PATH.format("GKEListJobsOperator.log")) + @mock.patch(GKE_OPERATORS_PATH.format("GKEHook")) + @mock.patch(GKE_OPERATORS_PATH.format("GKEKubernetesHook")) + def test_execute(self, mock_hook, cluster_hook, mock_log, mock_to_dict, mock_link): + mock_list_jobs_from_namespace = mock_hook.return_value.list_jobs_from_namespace + mock_list_jobs_all_namespaces = mock_hook.return_value.list_jobs_all_namespaces + mock_job_1, mock_job_2 = mock.MagicMock(), mock.MagicMock() + mock_jobs = mock.MagicMock(items=[mock_job_1, mock_job_2]) + mock_list_jobs_all_namespaces.return_value = mock_jobs + mock_to_dict_value = mock_to_dict.return_value + + mock_ti = mock.MagicMock() + context = {"ti": mock_ti} + + result = self.operator.execute(context=context) + + mock_list_jobs_all_namespaces.assert_called_once() + mock_list_jobs_from_namespace.assert_not_called() + mock_log.info.assert_has_calls( + [ + call("Retrieved description of Job:\n %s", mock_job_1), + call("Retrieved description of Job:\n %s", mock_job_2), + ] ) - - @mock.patch.dict(os.environ, {}) - @mock.patch(KUB_CREATE_RES_OPERATOR_EXEC) - @mock.patch(TEMP_FILE) - @mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info") - @mock.patch(GKE_HOOK_PATH) - def test_execute(self, mock_hook, fetch_cluster_info_mock, file_mock, exec_mock): - fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT) - self.gke_op.execute(context=mock.MagicMock()) - fetch_cluster_info_mock.assert_called_once() - - @mock.patch.dict(os.environ, {}) - @mock.patch( - "airflow.hooks.base.BaseHook.get_connection", - return_value=[Connection(extra=json.dumps({"keyfile_dict": '{"private_key": "r4nd0m_k3y"}'}))], - ) - @mock.patch(KUB_CREATE_RES_OPERATOR_EXEC) - @mock.patch(TEMP_FILE) - @mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info") - @mock.patch(GKE_HOOK_PATH) - def test_execute_with_impersonation_service_account( - self, mock_hook, fetch_cluster_info_mock, file_mock, exec_mock, get_con_mock - ): - fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT) - self.gke_op.impersonation_chain = "test_account@example.com" - self.gke_op.execute(context=mock.MagicMock()) - fetch_cluster_info_mock.assert_called_once() - - @mock.patch.dict(os.environ, {}) - @mock.patch( - "airflow.hooks.base.BaseHook.get_connection", - return_value=[Connection(extra=json.dumps({"keyfile_dict": '{"private_key": "r4nd0m_k3y"}'}))], - ) - @mock.patch(KUB_CREATE_RES_OPERATOR_EXEC) - @mock.patch(TEMP_FILE) - @mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info") - @mock.patch(GKE_HOOK_PATH) - def test_execute_with_impersonation_service_chain_one_element( - self, mock_hook, fetch_cluster_info_mock, file_mock, exec_mock, get_con_mock - ): - fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT) - self.gke_op.impersonation_chain = ["test_account@example.com"] - self.gke_op.execute(context=mock.MagicMock()) - - fetch_cluster_info_mock.assert_called_once() - - -class TestGKEDeleteCustomResourceOperator: - def setup_method(self): - self.gke_op = GKEDeleteCustomResourceOperator( - project_id=TEST_GCP_PROJECT_ID, - location=PROJECT_LOCATION, - cluster_name=CLUSTER_NAME, - task_id=PROJECT_TASK_ID, - yaml_conf=VALID_RESOURCE_YAML, + mock_to_dict.assert_has_calls([call(mock_jobs), call(mock_jobs)]) + mock_ti.xcom_push.assert_called_once_with(key="jobs_list", value=mock_to_dict_value) + 
mock_link.persist.assert_called_once_with(context=context, task_instance=self.operator) + assert result == mock_to_dict_value + + @mock.patch(GKE_OPERATORS_PATH.format("KubernetesEngineWorkloadsLink")) + @mock.patch(GKE_OPERATORS_PATH.format("V1JobList.to_dict")) + @mock.patch(GKE_OPERATORS_PATH.format("GKEListJobsOperator.log")) + @mock.patch(GKE_OPERATORS_PATH.format("GKEHook")) + @mock.patch(GKE_OPERATORS_PATH.format("GKEKubernetesHook")) + def test_execute_namespaced(self, mock_hook, cluster_hook, mock_log, mock_to_dict, mock_link): + mock_list_jobs_from_namespace = mock_hook.return_value.list_jobs_from_namespace + mock_list_jobs_all_namespaces = mock_hook.return_value.list_jobs_all_namespaces + mock_job_1, mock_job_2 = mock.MagicMock(), mock.MagicMock() + mock_jobs = mock.MagicMock(items=[mock_job_1, mock_job_2]) + mock_list_jobs_from_namespace.return_value = mock_jobs + mock_to_dict_value = mock_to_dict.return_value + + mock_ti = mock.MagicMock() + context = {"ti": mock_ti} + + self.operator.namespace = K8S_NAMESPACE + result = self.operator.execute(context=context) + + mock_list_jobs_all_namespaces.assert_not_called() + mock_list_jobs_from_namespace.assert_called_once_with(namespace=K8S_NAMESPACE) + mock_log.info.assert_has_calls( + [ + call("Retrieved description of Job:\n %s", mock_job_1), + call("Retrieved description of Job:\n %s", mock_job_2), + ] ) - - def test_template_fields(self): - assert set(KubernetesDeleteResourceOperator.template_fields).issubset( - GKEDeleteCustomResourceOperator.template_fields + mock_to_dict.assert_has_calls([call(mock_jobs), call(mock_jobs)]) + mock_ti.xcom_push.assert_called_once_with(key="jobs_list", value=mock_to_dict_value) + mock_link.persist.assert_called_once_with(context=context, task_instance=self.operator) + assert result == mock_to_dict_value + + @mock.patch(GKE_OPERATORS_PATH.format("KubernetesEngineWorkloadsLink")) + @mock.patch(GKE_OPERATORS_PATH.format("V1JobList.to_dict")) + @mock.patch(GKE_OPERATORS_PATH.format("GKEListJobsOperator.log")) + @mock.patch(GKE_OPERATORS_PATH.format("GKEHook")) + @mock.patch(GKE_OPERATORS_PATH.format("GKEKubernetesHook")) + def test_execute_not_do_xcom_push(self, mock_hook, cluster_hook, mock_log, mock_to_dict, mock_link): + mock_list_jobs_from_namespace = mock_hook.return_value.list_jobs_from_namespace + mock_list_jobs_all_namespaces = mock_hook.return_value.list_jobs_all_namespaces + mock_job_1, mock_job_2 = mock.MagicMock(), mock.MagicMock() + mock_jobs = mock.MagicMock(items=[mock_job_1, mock_job_2]) + mock_list_jobs_all_namespaces.return_value = mock_jobs + mock_to_dict_value = mock_to_dict.return_value + + mock_ti = mock.MagicMock() + context = {"ti": mock_ti} + + self.operator.do_xcom_push = False + result = self.operator.execute(context=context) + + mock_list_jobs_all_namespaces.assert_called_once() + mock_list_jobs_from_namespace.assert_not_called() + mock_log.info.assert_has_calls( + [ + call("Retrieved description of Job:\n %s", mock_job_1), + call("Retrieved description of Job:\n %s", mock_job_2), + ] ) + mock_to_dict.assert_called_once_with(mock_jobs) + mock_ti.xcom_push.assert_not_called() + mock_link.persist.assert_called_once_with(context=context, task_instance=self.operator) + assert result == mock_to_dict_value - @mock.patch.dict(os.environ, {}) - @mock.patch(KUB_DELETE_RES_OPERATOR_EXEC) - @mock.patch(TEMP_FILE) - @mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info") - @mock.patch(GKE_HOOK_PATH) - def test_execute(self, mock_hook, fetch_cluster_info_mock, file_mock, 
exec_mock): - fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT) - self.gke_op.execute(context=mock.MagicMock()) - fetch_cluster_info_mock.assert_called_once() - - @mock.patch.dict(os.environ, {}) - @mock.patch( - "airflow.hooks.base.BaseHook.get_connection", - return_value=[Connection(extra=json.dumps({"keyfile_dict": '{"private_key": "r4nd0m_k3y"}'}))], - ) - @mock.patch(KUB_DELETE_RES_OPERATOR_EXEC) - @mock.patch(TEMP_FILE) - @mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info") - @mock.patch(GKE_HOOK_PATH) - def test_execute_with_impersonation_service_account( - self, mock_hook, fetch_cluster_info_mock, file_mock, exec_mock, get_con_mock - ): - fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT) - self.gke_op.impersonation_chain = "test_account@example.com" - self.gke_op.execute(context=mock.MagicMock()) - fetch_cluster_info_mock.assert_called_once() - - @mock.patch.dict(os.environ, {}) - @mock.patch( - "airflow.hooks.base.BaseHook.get_connection", - return_value=[Connection(extra=json.dumps({"keyfile_dict": '{"private_key": "r4nd0m_k3y"}'}))], - ) - @mock.patch(KUB_DELETE_RES_OPERATOR_EXEC) - @mock.patch(TEMP_FILE) - @mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info") - @mock.patch(GKE_HOOK_PATH) - def test_execute_with_impersonation_service_chain_one_element( - self, mock_hook, fetch_cluster_info_mock, file_mock, exec_mock, get_con_mock - ): - fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT) - self.gke_op.impersonation_chain = ["test_account@example.com"] - self.gke_op.execute(context=mock.MagicMock()) - fetch_cluster_info_mock.assert_called_once() +class TestGKECreateCustomResourceOperator: + def test_template_fields(self): + assert set(GKECreateCustomResourceOperator.template_fields) == set( + GKEBaseOperator.template_fields + ) | set(KubernetesCreateResourceOperator.template_fields) + def test_gcp_conn_id_required(self): + with pytest.raises(AirflowException): + GKECreateCustomResourceOperator( + project_id=TEST_PROJECT_ID, + location=TEST_LOCATION, + cluster_name=GKE_CLUSTER_NAME, + task_id=TEST_TASK_ID, + yaml_conf_file="/path/to/yaml_conf_file", + gcp_conn_id=None, + ) -class TestGKEStartKueueJobOperator: - def setup_method(self): - self.gke_op = GKEStartKueueJobOperator( - project_id=TEST_GCP_PROJECT_ID, - location=PROJECT_LOCATION, - cluster_name=CLUSTER_NAME, - task_id=PROJECT_TASK_ID, - name=TASK_NAME, - namespace=NAMESPACE, - image=IMAGE, - queue_name=QUEUE_NAME, - ) - self.gke_op.job = mock.MagicMock( - name=TASK_NAME, - namespace=NAMESPACE, + def test_config_file_throws_error(self): + expected_error_message = ( + "config_file is not an allowed parameter for the GKECreateCustomResourceOperator." 
) + with pytest.raises(AirflowException, match=expected_error_message): + GKECreateCustomResourceOperator( + project_id=TEST_PROJECT_ID, + location=TEST_LOCATION, + cluster_name=GKE_CLUSTER_NAME, + gcp_conn_id=TEST_CONN_ID, + task_id=TEST_TASK_ID, + yaml_conf_file="/path/to/yaml_conf_file", + config_file="/path/to/alternative/kubeconfig", + ) + +class TestGKEDeleteCustomResourceOperator: def test_template_fields(self): - assert set(GKEStartJobOperator.template_fields).issubset(GKEStartKueueJobOperator.template_fields) - - @mock.patch.dict(os.environ, {}) - @mock.patch(KUB_JOB_OPERATOR_EXEC) - @mock.patch(TEMP_FILE) - @mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info") - @mock.patch(GKE_HOOK_PATH) - def test_execute(self, mock_hook, fetch_cluster_info_mock, file_mock, exec_mock): - fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT) - self.gke_op.execute(context=mock.MagicMock()) - fetch_cluster_info_mock.assert_called_once() + assert set(GKEDeleteCustomResourceOperator.template_fields) == set( + GKEBaseOperator.template_fields + ) | set(KubernetesDeleteResourceOperator.template_fields) - def test_config_file_throws_error(self): - with pytest.raises((TypeError, AirflowException), match="missing keyword argument 'queue_name'"): - GKEStartKueueJobOperator( - project_id=TEST_GCP_PROJECT_ID, - location=PROJECT_LOCATION, - cluster_name=CLUSTER_NAME, - task_id=PROJECT_TASK_ID, - name=TASK_NAME, - namespace=NAMESPACE, - image=IMAGE, - config_file="/path/to/alternative/kubeconfig", + def test_gcp_conn_id_required(self): + with pytest.raises(AirflowException): + GKEDeleteCustomResourceOperator( + project_id=TEST_PROJECT_ID, + location=TEST_LOCATION, + cluster_name=GKE_CLUSTER_NAME, + task_id=TEST_TASK_ID, + yaml_conf_file="/path/to/yaml_conf_file", + gcp_conn_id=None, ) - @mock.patch.dict(os.environ, {}) - @mock.patch( - "airflow.hooks.base.BaseHook.get_connection", - return_value=[Connection(extra=json.dumps({"keyfile_dict": '{"private_key": "r4nd0m_k3y"}'}))], - ) - @mock.patch(KUB_JOB_OPERATOR_EXEC) - @mock.patch(TEMP_FILE) - @mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info") - @mock.patch(GKE_HOOK_PATH) - def test_execute_with_impersonation_service_account( - self, mock_hook, fetch_cluster_info_mock, file_mock, exec_mock, get_con_mock - ): - fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT) - self.gke_op.impersonation_chain = "test_account@example.com" - self.gke_op.execute(context=mock.MagicMock()) - fetch_cluster_info_mock.assert_called_once() - - @mock.patch.dict(os.environ, {}) - @mock.patch( - "airflow.hooks.base.BaseHook.get_connection", - return_value=[Connection(extra=json.dumps({"keyfile_dict": '{"private_key": "r4nd0m_k3y"}'}))], - ) - @mock.patch(KUB_JOB_OPERATOR_EXEC) - @mock.patch(TEMP_FILE) - @mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info") - @mock.patch(GKE_HOOK_PATH) - def test_execute_with_impersonation_service_chain_one_element( - self, mock_hook, fetch_cluster_info_mock, file_mock, exec_mock, get_con_mock - ): - fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT) - self.gke_op.impersonation_chain = ["test_account@example.com"] - self.gke_op.execute(context=mock.MagicMock()) - - fetch_cluster_info_mock.assert_called_once() - - @pytest.mark.db_test - def test_default_gcp_conn_id(self): - gke_op = GKEStartKueueJobOperator( - project_id=TEST_GCP_PROJECT_ID, - location=PROJECT_LOCATION, - cluster_name=CLUSTER_NAME, - task_id=PROJECT_TASK_ID, - name=TASK_NAME, - namespace=NAMESPACE, - 
image=IMAGE, - queue_name=QUEUE_NAME, + def test_config_file_throws_error(self): + expected_error_message = ( + "config_file is not an allowed parameter for the GKEDeleteCustomResourceOperator." ) - gke_op._cluster_url = CLUSTER_URL - gke_op._ssl_ca_cert = SSL_CA_CERT - hook = gke_op.hook + with pytest.raises(AirflowException, match=expected_error_message): + GKEDeleteCustomResourceOperator( + project_id=TEST_PROJECT_ID, + location=TEST_LOCATION, + cluster_name=GKE_CLUSTER_NAME, + gcp_conn_id=TEST_CONN_ID, + task_id=TEST_TASK_ID, + yaml_conf_file="/path/to/yaml_conf_file", + config_file="/path/to/alternative/kubeconfig", + ) - assert hook.gcp_conn_id == "google_cloud_default" - @mock.patch( - "airflow.providers.google.common.hooks.base_google.GoogleBaseHook.get_connection", - return_value=Connection(conn_id="test_conn"), - ) - def test_gcp_conn_id(self, get_con_mock): - gke_op = GKEStartKueueJobOperator( - project_id=TEST_GCP_PROJECT_ID, - location=PROJECT_LOCATION, - cluster_name=CLUSTER_NAME, - task_id=PROJECT_TASK_ID, - name=TASK_NAME, - namespace=NAMESPACE, - image=IMAGE, - gcp_conn_id="test_conn", - queue_name=QUEUE_NAME, +class TestGKEStartKueueJobOperator: + def test_template_fields(self): + assert set(GKEStartKueueJobOperator.template_fields) == set(GKEBaseOperator.template_fields) | set( + KubernetesStartKueueJobOperator.template_fields ) - gke_op._cluster_url = CLUSTER_URL - gke_op._ssl_ca_cert = SSL_CA_CERT - hook = gke_op.hook - - assert hook.gcp_conn_id == "test_conn" class TestGKEDeleteJobOperator: - def setup_method(self): - self.gke_op = GKEDeleteJobOperator( - project_id=TEST_GCP_PROJECT_ID, - location=PROJECT_LOCATION, - cluster_name=CLUSTER_NAME, - task_id=PROJECT_TASK_ID, - name=TASK_NAME, - namespace=NAMESPACE, + def test_template_fields(self): + assert set(GKEDeleteJobOperator.template_fields) == set(GKEBaseOperator.template_fields) | set( + KubernetesDeleteJobOperator.template_fields ) - def test_template_fields(self): - assert set(KubernetesDeleteJobOperator.template_fields).issubset(GKEDeleteJobOperator.template_fields) - - @mock.patch.dict(os.environ, {}) - @mock.patch(DEL_KUB_JOB_OPERATOR_EXEC) - @mock.patch(TEMP_FILE) - @mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info") - @mock.patch(GKE_HOOK_PATH) - def test_execute(self, mock_hook, fetch_cluster_info_mock, file_mock, exec_mock): - fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT) - self.gke_op.execute(context=mock.MagicMock()) - fetch_cluster_info_mock.assert_called_once() + def test_gcp_conn_id_required(self): + with pytest.raises(AirflowException): + GKEDeleteJobOperator( + project_id=TEST_PROJECT_ID, + location=TEST_LOCATION, + cluster_name=GKE_CLUSTER_NAME, + name=K8S_JOB_NAME, + namespace=K8S_NAMESPACE, + task_id=TEST_TASK_ID, + gcp_conn_id=None, + ) def test_config_file_throws_error(self): - with pytest.raises(AirflowException): + expected_error_message = "config_file is not an allowed parameter for the GKEDeleteJobOperator." 
+        with pytest.raises(AirflowException, match=expected_error_message):
             GKEDeleteJobOperator(
-                project_id=TEST_GCP_PROJECT_ID,
-                location=PROJECT_LOCATION,
-                cluster_name=CLUSTER_NAME,
-                task_id=PROJECT_TASK_ID,
-                name=TASK_NAME,
-                namespace=NAMESPACE,
+                project_id=TEST_PROJECT_ID,
+                location=TEST_LOCATION,
+                cluster_name=GKE_CLUSTER_NAME,
+                name=K8S_JOB_NAME,
+                namespace=K8S_NAMESPACE,
+                task_id=TEST_TASK_ID,
                 config_file="/path/to/alternative/kubeconfig",
             )

-    @mock.patch.dict(os.environ, {})
-    @mock.patch(
-        "airflow.hooks.base.BaseHook.get_connection",
-        return_value=[Connection(extra=json.dumps({"keyfile_dict": '{"private_key": "r4nd0m_k3y"}'}))],
-    )
-    @mock.patch(DEL_KUB_JOB_OPERATOR_EXEC)
-    @mock.patch(TEMP_FILE)
-    @mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info")
-    @mock.patch(GKE_HOOK_PATH)
-    def test_execute_with_impersonation_service_account(
-        self, mock_hook, fetch_cluster_info_mock, file_mock, exec_mock, get_con_mock
-    ):
-        fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT)
-        self.gke_op.impersonation_chain = "test_account@example.com"
-        self.gke_op.execute(context=mock.MagicMock())
-        fetch_cluster_info_mock.assert_called_once()
-
-    @mock.patch.dict(os.environ, {})
-    @mock.patch(
-        "airflow.hooks.base.BaseHook.get_connection",
-        return_value=[Connection(extra=json.dumps({"keyfile_dict": '{"private_key": "r4nd0m_k3y"}'}))],
-    )
-    @mock.patch(DEL_KUB_JOB_OPERATOR_EXEC)
-    @mock.patch(TEMP_FILE)
-    @mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info")
-    @mock.patch(GKE_HOOK_PATH)
-    def test_execute_with_impersonation_service_chain_one_element(
-        self, mock_hook, fetch_cluster_info_mock, file_mock, exec_mock, get_con_mock
-    ):
-        fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT)
-        self.gke_op.impersonation_chain = ["test_account@example.com"]
-        self.gke_op.execute(context=mock.MagicMock())
-
-        fetch_cluster_info_mock.assert_called_once()
-
-    @pytest.mark.db_test
-    def test_default_gcp_conn_id(self):
-        gke_op = GKEDeleteJobOperator(
-            project_id=TEST_GCP_PROJECT_ID,
-            location=PROJECT_LOCATION,
-            cluster_name=CLUSTER_NAME,
-            task_id=PROJECT_TASK_ID,
-            name=TASK_NAME,
-            namespace=NAMESPACE,
-        )
-        gke_op._cluster_url = CLUSTER_URL
-        gke_op._ssl_ca_cert = SSL_CA_CERT
-        hook = gke_op.hook
-
-        assert hook.gcp_conn_id == "google_cloud_default"
-
-    @mock.patch(
-        "airflow.providers.google.common.hooks.base_google.GoogleBaseHook.get_connection",
-        return_value=Connection(conn_id="test_conn"),
-    )
-    def test_gcp_conn_id(self, get_con_mock):
-        gke_op = GKEDeleteJobOperator(
-            project_id=TEST_GCP_PROJECT_ID,
-            location=PROJECT_LOCATION,
-            cluster_name=CLUSTER_NAME,
-            task_id=PROJECT_TASK_ID,
-            name=TASK_NAME,
-            namespace=NAMESPACE,
-            gcp_conn_id="test_conn",
-        )
-        gke_op._cluster_url = CLUSTER_URL
-        gke_op._ssl_ca_cert = SSL_CA_CERT
-        hook = gke_op.hook
-
-        assert hook.gcp_conn_id == "test_conn"
-

 class TestGKESuspendJobOperator:
     def setup_method(self):
-        self.gke_op = GKESuspendJobOperator(
-            project_id=TEST_GCP_PROJECT_ID,
-            location=PROJECT_LOCATION,
-            cluster_name=CLUSTER_NAME,
-            task_id=PROJECT_TASK_ID,
-            name=TASK_NAME,
-            namespace=NAMESPACE,
+        self.operator = GKESuspendJobOperator(
+            project_id=TEST_PROJECT_ID,
+            location=TEST_LOCATION,
+            cluster_name=GKE_CLUSTER_NAME,
+            task_id=TEST_TASK_ID,
+            name=K8S_JOB_NAME,
+            namespace=K8S_NAMESPACE,
         )

-    def test_config_file_throws_error(self):
-        with pytest.raises(
-            (TypeError, AirflowException), match="Invalid arguments were passed to .*\n.*'config_file'"
-        ):
-            GKESuspendJobOperator(
-                project_id=TEST_GCP_PROJECT_ID,
-                location=PROJECT_LOCATION,
-                cluster_name=CLUSTER_NAME,
-                task_id=PROJECT_TASK_ID,
-                name=TASK_NAME,
-                namespace=NAMESPACE,
-                config_file="/path/to/alternative/kubeconfig",
-            )
-
-    @mock.patch.dict(os.environ, {})
-    @mock.patch(TEMP_FILE)
-    @mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info")
-    @mock.patch(GKE_HOOK_PATH)
-    @mock.patch(GKE_K8S_HOOK_PATH)
-    def test_execute(self, mock_job_hook, mock_hook, fetch_cluster_info_mock, file_mock):
-        mock_job_hook.return_value.get_job.return_value = mock.MagicMock()
-        fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT)
-        self.gke_op.execute(context=mock.MagicMock())
-        fetch_cluster_info_mock.assert_called_once()
-
-    @mock.patch.dict(os.environ, {})
-    @mock.patch(
-        "airflow.hooks.base.BaseHook.get_connection",
-        return_value=[Connection(extra=json.dumps({"keyfile_dict": '{"private_key": "r4nd0m_k3y"}'}))],
-    )
-    @mock.patch(TEMP_FILE)
-    @mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info")
-    @mock.patch(GKE_HOOK_PATH)
-    @mock.patch(GKE_K8S_HOOK_PATH)
-    def test_execute_with_impersonation_service_account(
-        self, mock_job_hook, mock_hook, fetch_cluster_info_mock, file_mock, get_con_mock
-    ):
-        mock_job_hook.return_value.get_job.return_value = mock.MagicMock()
-        fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT)
-        self.gke_op.impersonation_chain = "test_account@example.com"
-        self.gke_op.execute(context=mock.MagicMock())
-        fetch_cluster_info_mock.assert_called_once()
-
-    @mock.patch.dict(os.environ, {})
-    @mock.patch(
-        "airflow.hooks.base.BaseHook.get_connection",
-        return_value=[Connection(extra=json.dumps({"keyfile_dict": '{"private_key": "r4nd0m_k3y"}'}))],
-    )
-    @mock.patch(TEMP_FILE)
-    @mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info")
-    @mock.patch(GKE_HOOK_PATH)
-    @mock.patch(GKE_K8S_HOOK_PATH)
-    def test_execute_with_impersonation_service_chain_one_element(
-        self, mock_job_hook, mock_hook, fetch_cluster_info_mock, file_mock, get_con_mock
-    ):
-        fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT)
-        self.gke_op.impersonation_chain = ["test_account@example.com"]
-        self.gke_op.execute(context=mock.MagicMock())
-
-        fetch_cluster_info_mock.assert_called_once()
-
-    @pytest.mark.db_test
-    @mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info")
-    def test_default_gcp_conn_id(self, fetch_cluster_info_mock):
-        gke_op = GKESuspendJobOperator(
-            project_id=TEST_GCP_PROJECT_ID,
-            location=PROJECT_LOCATION,
-            cluster_name=CLUSTER_NAME,
-            task_id=PROJECT_TASK_ID,
-            name=TASK_NAME,
-            namespace=NAMESPACE,
+    def test_template_fields(self):
+        expected_template_fields = {"name", "namespace"} | set(GKEBaseOperator.template_fields)
+        assert set(GKESuspendJobOperator.template_fields) == expected_template_fields
+
+    @mock.patch(GKE_OPERATORS_PATH.format("k8s.V1Job.to_dict"))
+    @mock.patch(GKE_OPERATORS_PATH.format("KubernetesEngineJobLink"))
+    @mock.patch(GKE_OPERATORS_PATH.format("GKESuspendJobOperator.log"))
+    @mock.patch(GKE_OPERATORS_PATH.format("GKEKubernetesHook"))
+    @mock.patch(GKE_OPERATORS_PATH.format("GKEHook"))
+    def test_execute(self, mock_cluster_hook, mock_hook, mock_log, mock_link, mock_to_dict):
+        mock_patch_namespaced_job = mock_hook.return_value.patch_namespaced_job
+        mock_job = mock_patch_namespaced_job.return_value
+        expected_result = mock_to_dict.return_value
+        mock_context = mock.MagicMock()
+
+        result = self.operator.execute(context=mock_context)
+
+        mock_patch_namespaced_job.assert_called_once_with(
+            job_name=K8S_JOB_NAME,
+            namespace=K8S_NAMESPACE,
+            body={"spec": {"suspend": True}},
         )
-        fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT)
-        hook = gke_op.hook
-
-        assert hook.gcp_conn_id == "google_cloud_default"
-
-    @mock.patch(
-        "airflow.providers.google.common.hooks.base_google.GoogleBaseHook.get_connection",
-        return_value=Connection(conn_id="test_conn"),
-    )
-    @mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info")
-    @mock.patch(GKE_HOOK_PATH)
-    def test_gcp_conn_id(self, mock_hook, fetch_cluster_info_mock, mock_gke_conn):
-        gke_op = GKESuspendJobOperator(
-            project_id=TEST_GCP_PROJECT_ID,
-            location=PROJECT_LOCATION,
-            cluster_name=CLUSTER_NAME,
-            task_id=PROJECT_TASK_ID,
-            name=TASK_NAME,
-            namespace=NAMESPACE,
-            gcp_conn_id="test_conn",
+        mock_log.info.assert_called_once_with(
+            "Job %s from cluster %s was suspended.",
+            K8S_JOB_NAME,
+            GKE_CLUSTER_NAME,
         )
-        fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT)
-        hook = gke_op.hook
-
-        assert hook.gcp_conn_id == "test_conn"
+        mock_link.persist.assert_called_once_with(context=mock_context, task_instance=self.operator)
+        mock_to_dict.assert_called_once_with(mock_job)
+        assert result == expected_result


 class TestGKEResumeJobOperator:
     def setup_method(self):
-        self.gke_op = GKEResumeJobOperator(
-            project_id=TEST_GCP_PROJECT_ID,
-            location=PROJECT_LOCATION,
-            cluster_name=CLUSTER_NAME,
-            task_id=PROJECT_TASK_ID,
-            name=TASK_NAME,
-            namespace=NAMESPACE,
+        self.operator = GKEResumeJobOperator(
+            project_id=TEST_PROJECT_ID,
+            location=TEST_LOCATION,
+            cluster_name=GKE_CLUSTER_NAME,
+            task_id=TEST_TASK_ID,
+            name=K8S_JOB_NAME,
+            namespace=K8S_NAMESPACE,
         )

-    def test_config_file_throws_error(self):
-        with pytest.raises(
-            (TypeError, AirflowException), match="Invalid arguments were passed to .*\n.*'config_file'"
-        ):
-            GKEResumeJobOperator(
-                project_id=TEST_GCP_PROJECT_ID,
-                location=PROJECT_LOCATION,
-                cluster_name=CLUSTER_NAME,
-                task_id=PROJECT_TASK_ID,
-                name=TASK_NAME,
-                namespace=NAMESPACE,
-                config_file="/path/to/alternative/kubeconfig",
-            )
-
-    @mock.patch.dict(os.environ, {})
-    @mock.patch(TEMP_FILE)
-    @mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info")
-    @mock.patch(GKE_HOOK_PATH)
-    @mock.patch(GKE_K8S_HOOK_PATH)
-    def test_execute(self, mock_job_hook, mock_hook, fetch_cluster_info_mock, file_mock):
-        mock_job_hook.return_value.get_job.return_value = mock.MagicMock()
-        fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT)
-        self.gke_op.execute(context=mock.MagicMock())
-        fetch_cluster_info_mock.assert_called_once()
-
-    @mock.patch.dict(os.environ, {})
-    @mock.patch(
-        "airflow.hooks.base.BaseHook.get_connection",
-        return_value=[Connection(extra=json.dumps({"keyfile_dict": '{"private_key": "r4nd0m_k3y"}'}))],
-    )
-    @mock.patch(TEMP_FILE)
-    @mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info")
-    @mock.patch(GKE_HOOK_PATH)
-    @mock.patch(GKE_K8S_HOOK_PATH)
-    def test_execute_with_impersonation_service_account(
-        self, mock_job_hook, mock_hook, fetch_cluster_info_mock, file_mock, get_con_mock
-    ):
-        mock_job_hook.return_value.get_job.return_value = mock.MagicMock()
-        fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT)
-        self.gke_op.impersonation_chain = "test_account@example.com"
-        self.gke_op.execute(context=mock.MagicMock())
-        fetch_cluster_info_mock.assert_called_once()
-
-    @mock.patch.dict(os.environ, {})
-    @mock.patch(
-        "airflow.hooks.base.BaseHook.get_connection",
-        return_value=[Connection(extra=json.dumps({"keyfile_dict": '{"private_key": "r4nd0m_k3y"}'}))],
-    )
-    @mock.patch(TEMP_FILE)
-    @mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info")
-    @mock.patch(GKE_HOOK_PATH)
-    @mock.patch(GKE_K8S_HOOK_PATH)
-    def test_execute_with_impersonation_service_chain_one_element(
-        self, mock_job_hook, mock_hook, fetch_cluster_info_mock, file_mock, get_con_mock
-    ):
-        fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT)
-        self.gke_op.impersonation_chain = ["test_account@example.com"]
-        self.gke_op.execute(context=mock.MagicMock())
-
-        fetch_cluster_info_mock.assert_called_once()
-
-    @pytest.mark.db_test
-    @mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info")
-    def test_default_gcp_conn_id(self, fetch_cluster_info_mock):
-        gke_op = GKEResumeJobOperator(
-            project_id=TEST_GCP_PROJECT_ID,
-            location=PROJECT_LOCATION,
-            cluster_name=CLUSTER_NAME,
-            task_id=PROJECT_TASK_ID,
-            name=TASK_NAME,
-            namespace=NAMESPACE,
+    def test_template_fields(self):
+        expected_template_fields = {"name", "namespace"} | set(GKEBaseOperator.template_fields)
+        assert set(GKEResumeJobOperator.template_fields) == expected_template_fields
+
+    @mock.patch(GKE_OPERATORS_PATH.format("k8s.V1Job.to_dict"))
+    @mock.patch(GKE_OPERATORS_PATH.format("KubernetesEngineJobLink"))
+    @mock.patch(GKE_OPERATORS_PATH.format("GKEResumeJobOperator.log"))
+    @mock.patch(GKE_OPERATORS_PATH.format("GKEKubernetesHook"))
+    @mock.patch(GKE_OPERATORS_PATH.format("GKEHook"))
+    def test_execute(self, mock_cluster_hook, mock_hook, mock_log, mock_link, mock_to_dict):
+        mock_patch_namespaced_job = mock_hook.return_value.patch_namespaced_job
+        mock_job = mock_patch_namespaced_job.return_value
+        expected_result = mock_to_dict.return_value
+        mock_context = mock.MagicMock()
+
+        result = self.operator.execute(context=mock_context)
+
+        mock_patch_namespaced_job.assert_called_once_with(
+            job_name=K8S_JOB_NAME,
+            namespace=K8S_NAMESPACE,
+            body={"spec": {"suspend": False}},
         )
-        fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT)
-        hook = gke_op.hook
-
-        assert hook.gcp_conn_id == "google_cloud_default"
-
-    @mock.patch(
-        "airflow.providers.google.common.hooks.base_google.GoogleBaseHook.get_connection",
-        return_value=Connection(conn_id="test_conn"),
-    )
-    @mock.patch(f"{GKE_CLUSTER_AUTH_DETAILS_PATH}.fetch_cluster_info")
-    @mock.patch(GKE_HOOK_PATH)
-    def test_gcp_conn_id(self, mock_hook, fetch_cluster_info_mock, mock_gke_conn):
-        gke_op = GKEResumeJobOperator(
-            project_id=TEST_GCP_PROJECT_ID,
-            location=PROJECT_LOCATION,
-            cluster_name=CLUSTER_NAME,
-            task_id=PROJECT_TASK_ID,
-            name=TASK_NAME,
-            namespace=NAMESPACE,
-            gcp_conn_id="test_conn",
+        mock_log.info.assert_called_once_with(
+            "Job %s from cluster %s was resumed.",
+            K8S_JOB_NAME,
+            GKE_CLUSTER_NAME,
         )
-        fetch_cluster_info_mock.return_value = (CLUSTER_URL, SSL_CA_CERT)
-        hook = gke_op.hook
-
-        assert hook.gcp_conn_id == "test_conn"
+        mock_link.persist.assert_called_once_with(context=mock_context, task_instance=self.operator)
+        mock_to_dict.assert_called_once_with(mock_job)
+        assert result == expected_result
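Reviewer note: the rewritten test_execute cases pin down the exact patch body the suspend/resume operators send. For context, a minimal sketch of the equivalent raw Kubernetes client call is below; the job name, namespace, and kubeconfig loading are illustrative placeholders, not values from this change.

```python
from kubernetes import client, config

# Load credentials from the local kubeconfig; inside a cluster one would
# typically call config.load_incluster_config() instead.
config.load_kube_config()
batch_v1 = client.BatchV1Api()

# Patching spec.suspend to True tells the Job controller to delete the Job's
# active pods without deleting the Job object itself; patching it back to
# False (what GKEResumeJobOperator asserts above) lets the controller
# recreate them.
batch_v1.patch_namespaced_job(
    name="example-job",   # placeholder Job name
    namespace="default",  # placeholder namespace
    body={"spec": {"suspend": True}},
)
```

This suspend flag is the same mechanism Kueue relies on: workloads are created suspended, and the queue controller unsuspends them once they are admitted.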