From 0c56d294e1b4238c51448a540717506d12c91f79 Mon Sep 17 00:00:00 2001
From: Ulada Zakharava
Date: Mon, 4 Nov 2024 09:45:42 +0000
Subject: [PATCH] Add gcloud command to DataprocCreateClusterOperator to support creating Dataproc on GKE clusters

---
 .../providers/google/cloud/hooks/dataproc.py | 76 ++++++++++++++++++-
 .../google/cloud/operators/dataproc.py | 2 +-
 .../tests/google/cloud/hooks/test_dataproc.py | 51 +++++++++++++
 .../cloud/dataproc/example_dataproc_gke.py | 2 +-
 4 files changed, 128 insertions(+), 3 deletions(-)

diff --git a/providers/src/airflow/providers/google/cloud/hooks/dataproc.py b/providers/src/airflow/providers/google/cloud/hooks/dataproc.py
index f8bedec9853ed..9e2f786f6c1eb 100644
--- a/providers/src/airflow/providers/google/cloud/hooks/dataproc.py
+++ b/providers/src/airflow/providers/google/cloud/hooks/dataproc.py
@@ -19,6 +19,8 @@
 from __future__ import annotations
 
+import shlex
+import subprocess
 import time
 import uuid
 from collections.abc import MutableSequence
@@ -261,6 +263,50 @@ def get_operations_client(self, region: str | None):
         """Create a OperationsClient."""
         return self.get_batch_client(region=region).transport.operations_client
 
+    def dataproc_options_to_args(self, options: dict) -> list[str]:
+        """
+        Return formatted cluster parameters built from a dictionary of arguments.
+
+        :param options: Dictionary with options
+        :return: List of arguments
+        """
+        if not options:
+            return []
+
+        args: list[str] = []
+        for attr, value in options.items():
+            self.log.info("Attribute: %s, value: %s", attr, value)
+            if value is None or (isinstance(value, bool) and value):
+                args.append(f"--{attr}")
+            elif isinstance(value, bool) and not value:
+                continue
+            elif isinstance(value, list):
+                args.extend([f"--{attr}={v}" for v in value])
+            else:
+                args.append(f"--{attr}={value}")
+        return args
+
+    def _build_gcloud_command(self, command: list[str], parameters: dict[str, str]) -> list[str]:
+        return [*command, *(self.dataproc_options_to_args(parameters))]
+
+    def _create_dataproc_cluster_with_gcloud(self, cmd: list[str]) -> str:
+        """Create a Dataproc cluster with a gcloud command and return the command's output."""
+        self.log.info("Executing command: %s", " ".join(shlex.quote(c) for c in cmd))
+        success_code = 0
+
+        with self.provide_authorized_gcloud():
+            proc = subprocess.run(cmd, capture_output=True)
+
+        if proc.returncode != success_code:
+            stderr_last_20_lines = "\n".join(proc.stderr.decode().strip().splitlines()[-20:])
+            raise AirflowException(
+                f"Process exited with a non-zero exit code. Exit code: {proc.returncode}. Error details: "
+                f"{stderr_last_20_lines}"
+            )
+
+        response = proc.stdout.decode().strip()
+        return response
+
     def wait_for_operation(
         self,
         operation: Operation,
@@ -289,7 +335,7 @@ def create_cluster(
         retry: Retry | _MethodDefault = DEFAULT,
         timeout: float | None = None,
         metadata: Sequence[tuple[str, str]] = (),
-    ) -> Operation:
+    ) -> Operation | str:
         """
         Create a cluster in a specified project.
@@ -326,6 +372,34 @@ def create_cluster(
             "project_id": project_id,
             "cluster_name": cluster_name,
         }
+
+        if virtual_cluster_config and "kubernetes_cluster_config" in virtual_cluster_config:
+            kube_config = virtual_cluster_config["kubernetes_cluster_config"]["gke_cluster_config"]
+            try:
+                spark_engine_version = virtual_cluster_config["kubernetes_cluster_config"][
+                    "kubernetes_software_config"
+                ]["component_version"]["SPARK"]
+            except KeyError:
+                spark_engine_version = "latest"
+            gke_cluster_name = kube_config["gke_cluster_target"].rsplit("/", 1)[1]
+            gke_pools = kube_config["node_pool_target"][0]
+            gke_pool_name = gke_pools["node_pool"].rsplit("/", 1)[1]
+            gke_pool_role = gke_pools["roles"][0]
+            gke_pool_machine_type = gke_pools["node_pool_config"]["config"]["machine_type"]
+            gcp_flags = {
+                "region": region,
+                "gke-cluster": gke_cluster_name,
+                "spark-engine-version": spark_engine_version,
+                "pools": f"name={gke_pool_name},roles={gke_pool_role.lower()},machineType={gke_pool_machine_type},min=1,max=10",
+                "setup-workload-identity": None,
+            }
+            cmd = self._build_gcloud_command(
+                command=["gcloud", "dataproc", "clusters", "gke", "create", cluster_name],
+                parameters=gcp_flags,
+            )
+            response = self._create_dataproc_cluster_with_gcloud(cmd=cmd)
+            return response
+
         if virtual_cluster_config is not None:
             cluster["virtual_cluster_config"] = virtual_cluster_config  # type: ignore
         if cluster_config is not None:
diff --git a/providers/src/airflow/providers/google/cloud/operators/dataproc.py b/providers/src/airflow/providers/google/cloud/operators/dataproc.py
index 270114e5e53ae..d8569e9cb932a 100644
--- a/providers/src/airflow/providers/google/cloud/operators/dataproc.py
+++ b/providers/src/airflow/providers/google/cloud/operators/dataproc.py
@@ -818,7 +818,7 @@ def execute(self, context: Context) -> dict:
         try:
             # First try to create a new cluster
             operation = self._create_cluster(hook)
-            if not self.deferrable:
+            if not self.deferrable and not isinstance(operation, str):
                 cluster = hook.wait_for_operation(
                     timeout=self.timeout, result_retry=self.retry, operation=operation
                 )
diff --git a/providers/tests/google/cloud/hooks/test_dataproc.py b/providers/tests/google/cloud/hooks/test_dataproc.py
index 88839dabb8143..1fdf0d0b7f0e2 100644
--- a/providers/tests/google/cloud/hooks/test_dataproc.py
+++ b/providers/tests/google/cloud/hooks/test_dataproc.py
@@ -44,6 +44,27 @@
 GCP_LOCATION = "global"
 GCP_PROJECT = "test-project"
 CLUSTER_CONFIG = {"test": "test"}
+VIRTUAL_CLUSTER_CONFIG = {
+    "kubernetes_cluster_config": {
+        "gke_cluster_config": {
+            "gke_cluster_target": "projects/project_id/locations/region/clusters/gke_cluster_name",
+            "node_pool_target": [
+                {
+                    "node_pool": "projects/project_id/locations/region/clusters/gke_cluster_name/nodePools/dp",
+                    "roles": ["DEFAULT"],
+                    "node_pool_config": {
+                        "config": {
+                            "preemptible": False,
+                            "machine_type": "e2-standard-4",
+                        }
+                    },
+                }
+            ],
+        },
+        "kubernetes_software_config": {"component_version": {"SPARK": "3"}},
+    },
+    "staging_bucket": "test-staging-bucket",
+}
 LABELS = {"test": "test"}
 CLUSTER_NAME = "cluster-name"
 CLUSTER = {
@@ -174,6 +195,36 @@ def test_create_cluster(self, mock_client):
             timeout=None,
         )
 
+    @mock.patch(DATAPROC_STRING.format("DataprocHook._create_dataproc_cluster_with_gcloud"))
+    @mock.patch(DATAPROC_STRING.format("DataprocHook._build_gcloud_command"))
+    @mock.patch(DATAPROC_STRING.format("DataprocHook.provide_authorized_gcloud"))
+    @mock.patch(DATAPROC_STRING.format("subprocess.run"))
+    def test_create_cluster_with_virtual_cluster_config(
+        self,
+        mock_run,
+        mock_provide_authorized_gcloud,
+        mock_build_gcloud_command,
+        mock_create_dataproc_cluster_with_gcloud,
+    ):
+        self.hook.create_cluster(
+            project_id=GCP_PROJECT,
+            region=GCP_LOCATION,
+            cluster_name=CLUSTER_NAME,
+            cluster_config=CLUSTER_CONFIG,
+            virtual_cluster_config=VIRTUAL_CLUSTER_CONFIG,
+            labels=LABELS,
+        )
+        mock_build_gcloud_command.assert_called_once_with(
+            command=["gcloud", "dataproc", "clusters", "gke", "create", CLUSTER_NAME],
+            parameters={
+                "region": GCP_LOCATION,
+                "gke-cluster": "gke_cluster_name",
+                "spark-engine-version": "3",
+                "pools": "name=dp,roles=default,machineType=e2-standard-4,min=1,max=10",
+                "setup-workload-identity": None,
+            },
+        )
+
     @mock.patch(DATAPROC_STRING.format("DataprocHook.get_cluster_client"))
     def test_delete_cluster(self, mock_client):
         self.hook.delete_cluster(project_id=GCP_PROJECT, region=GCP_LOCATION, cluster_name=CLUSTER_NAME)
diff --git a/providers/tests/system/google/cloud/dataproc/example_dataproc_gke.py b/providers/tests/system/google/cloud/dataproc/example_dataproc_gke.py
index 8048a23283e4c..0be3f53ea6f18 100644
--- a/providers/tests/system/google/cloud/dataproc/example_dataproc_gke.py
+++ b/providers/tests/system/google/cloud/dataproc/example_dataproc_gke.py
@@ -84,7 +84,7 @@
                 }
             ],
         },
-        "kubernetes_software_config": {"component_version": {"SPARK": b"3"}},
+        "kubernetes_software_config": {"component_version": {"SPARK": "3"}},
     },
     "staging_bucket": "test-staging-bucket",
 }
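
Note (illustration only, not part of the patch): given the VIRTUAL_CLUSTER_CONFIG fixture used in the test above, the new code path in DataprocHook.create_cluster builds a gcloud invocation via _build_gcloud_command and dataproc_options_to_args and runs it under provide_authorized_gcloud(). A minimal sketch of the resulting command list, assuming the fixture values shown in the test hunk (the ordering follows the gcp_flags dict in the hook), is:

    # Sketch of the argv handed to subprocess.run() for the test fixture.
    # A None value (here "setup-workload-identity") becomes a bare "--flag";
    # all other values are rendered as "--flag=value" by dataproc_options_to_args.
    cmd = [
        "gcloud", "dataproc", "clusters", "gke", "create", "cluster-name",
        "--region=global",
        "--gke-cluster=gke_cluster_name",
        "--spark-engine-version=3",
        "--pools=name=dp,roles=default,machineType=e2-standard-4,min=1,max=10",
        "--setup-workload-identity",
    ]

The hook returns the command's stdout as a string, which is why DataprocCreateClusterOperator.execute now skips wait_for_operation when _create_cluster yields a str instead of an Operation.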