Add gcloud command to DataprocCreateClusterOperator to be able to create dataproc on GKE cluster
Ulada Zakharava authored and MaksYermak committed Nov 19, 2024
1 parent 298f1fd commit 0c56d29
Showing 4 changed files with 128 additions and 3 deletions.
76 changes: 75 additions & 1 deletion providers/src/airflow/providers/google/cloud/hooks/dataproc.py
@@ -19,6 +19,8 @@

from __future__ import annotations

import shlex
import subprocess
import time
import uuid
from collections.abc import MutableSequence
@@ -261,6 +263,50 @@ def get_operations_client(self, region: str | None):
"""Create a OperationsClient."""
return self.get_batch_client(region=region).transport.operations_client

def dataproc_options_to_args(self, options: dict) -> list[str]:
"""
Return a list of formatted cluster arguments built from a dictionary of options.

:param options: Dictionary with options
:return: List of arguments
"""
if not options:
return []

args: list[str] = []
for attr, value in options.items():
self.log.info("Attribute: %s, value: %s", attr, value)
if value is None or (isinstance(value, bool) and value):
args.append(f"--{attr}")
elif isinstance(value, bool) and not value:
continue
elif isinstance(value, list):
args.extend([f"--{attr}={v}" for v in value])
else:
args.append(f"--{attr}={value}")
return args
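For illustration only (not part of the diff, and assuming hook is a DataprocHook instance), a parameter dict renders like this; the flag names are arbitrary examples:

options = {"region": "us-central1", "quiet": True, "labels": ["a=1", "b=2"], "async": None, "verbose": False}
hook.dataproc_options_to_args(options)
# -> ["--region=us-central1", "--quiet", "--labels=a=1", "--labels=b=2", "--async"]
# None and True become bare flags, False values are dropped, and lists repeat the flag once per element.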

def _build_gcloud_command(self, command: list[str], parameters: dict[str, str]) -> list[str]:
return [*command, *(self.dataproc_options_to_args(parameters))]

def _create_dataproc_cluster_with_gcloud(self, cmd: list[str]) -> str:
"""Create a Dataproc cluster with a gcloud command and return the command's output."""
self.log.info("Executing command: %s", " ".join(shlex.quote(c) for c in cmd))
success_code = 0

with self.provide_authorized_gcloud():
proc = subprocess.run(cmd, capture_output=True)

if proc.returncode != success_code:
stderr_last_20_lines = "\n".join(proc.stderr.decode().strip().splitlines()[-20:])
raise AirflowException(
f"Process exit with non-zero exit code. Exit code: {proc.returncode}. Error Details : "
f"{stderr_last_20_lines}"
)

response = proc.stdout.decode().strip()
return response
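A minimal sketch (illustrative, with made-up values) of how _build_gcloud_command stitches the base command and the rendered flags into a single argv list:

cmd = hook._build_gcloud_command(
    command=["gcloud", "dataproc", "clusters", "gke", "create", "my-cluster"],
    parameters={"region": "us-central1", "setup-workload-identity": None},
)
# -> ["gcloud", "dataproc", "clusters", "gke", "create", "my-cluster",
#     "--region=us-central1", "--setup-workload-identity"]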

def wait_for_operation(
self,
operation: Operation,
@@ -289,7 +335,7 @@ def create_cluster(
retry: Retry | _MethodDefault = DEFAULT,
timeout: float | None = None,
metadata: Sequence[tuple[str, str]] = (),
) -> Operation:
) -> Operation | str:
"""
Create a cluster in a specified project.
Expand Down Expand Up @@ -326,6 +372,34 @@ def create_cluster(
"project_id": project_id,
"cluster_name": cluster_name,
}

if virtual_cluster_config and "kubernetes_cluster_config" in virtual_cluster_config:
kube_config = virtual_cluster_config["kubernetes_cluster_config"]["gke_cluster_config"]
try:
spark_engine_version = virtual_cluster_config["kubernetes_cluster_config"][
"kubernetes_software_config"
]["component_version"]["SPARK"]
except KeyError:
spark_engine_version = "latest"
gke_cluster_name = kube_config["gke_cluster_target"].rsplit("/", 1)[1]
gke_pools = kube_config["node_pool_target"][0]
gke_pool_name = gke_pools["node_pool"].rsplit("/", 1)[1]
gke_pool_role = gke_pools["roles"][0]
gke_pool_machine_type = gke_pools["node_pool_config"]["config"]["machine_type"]
gcp_flags = {
"region": region,
"gke-cluster": gke_cluster_name,
"spark-engine-version": spark_engine_version,
"pools": f"name={gke_pool_name},roles={gke_pool_role.lower()},machineType={gke_pool_machine_type},min=1,max=10",
"setup-workload-identity": None,
}
cmd = self._build_gcloud_command(
command=["gcloud", "dataproc", "clusters", "gke", "create", cluster_name],
parameters=gcp_flags,
)
response = self._create_dataproc_cluster_with_gcloud(cmd=cmd)
return response
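For the sample configuration used in the tests below, the assembled cmd works out to the following (an illustration, not part of the diff):

# ["gcloud", "dataproc", "clusters", "gke", "create", "cluster-name",
#  "--region=global",
#  "--gke-cluster=gke_cluster_name",
#  "--spark-engine-version=3",
#  "--pools=name=dp,roles=default,machineType=e2-standard-4,min=1,max=10",
#  "--setup-workload-identity"]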

if virtual_cluster_config is not None:
cluster["virtual_cluster_config"] = virtual_cluster_config # type: ignore
if cluster_config is not None:
2 changes: 1 addition & 1 deletion providers/src/airflow/providers/google/cloud/operators/dataproc.py
@@ -818,7 +818,7 @@ def execute(self, context: Context) -> dict:
try:
# First try to create a new cluster
operation = self._create_cluster(hook)
if not self.deferrable:
if not self.deferrable and type(operation) is not str:  # the gcloud GKE path returns a str, not an Operation
cluster = hook.wait_for_operation(
timeout=self.timeout, result_retry=self.retry, operation=operation
)
51 changes: 51 additions & 0 deletions providers/tests/google/cloud/hooks/test_dataproc.py
@@ -44,6 +44,27 @@
GCP_LOCATION = "global"
GCP_PROJECT = "test-project"
CLUSTER_CONFIG = {"test": "test"}
VIRTUAL_CLUSTER_CONFIG = {
"kubernetes_cluster_config": {
"gke_cluster_config": {
"gke_cluster_target": "projects/project_id/locations/region/clusters/gke_cluster_name",
"node_pool_target": [
{
"node_pool": "projects/project_id/locations/region/clusters/gke_cluster_name/nodePools/dp",
"roles": ["DEFAULT"],
"node_pool_config": {
"config": {
"preemptible": False,
"machine_type": "e2-standard-4",
}
},
}
],
},
"kubernetes_software_config": {"component_version": {"SPARK": "3"}},
},
"staging_bucket": "test-staging-bucket",
}
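The hook reduces the fully qualified resource paths above to short names with rsplit, for example:

"projects/project_id/locations/region/clusters/gke_cluster_name".rsplit("/", 1)[1]
# -> "gke_cluster_name"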
LABELS = {"test": "test"}
CLUSTER_NAME = "cluster-name"
CLUSTER = {
@@ -174,6 +195,36 @@ def test_create_cluster(self, mock_client):
timeout=None,
)

@mock.patch(DATAPROC_STRING.format("DataprocHook._create_dataflow_cluster_with_gcloud"))
@mock.patch(DATAPROC_STRING.format("DataprocHook._build_gcloud_command"))
@mock.patch(DATAPROC_STRING.format("DataprocHook.provide_authorized_gcloud"))
@mock.patch(DATAPROC_STRING.format("subprocess.run"))
def test_create_cluster_with_virtual_cluster_config(
self,
mock_run,
mock_provide_authorized_gcloud,
mock_build_gcloud_command,
mock_create_dataproc_cluster_with_gcloud,
):
self.hook.create_cluster(
project_id=GCP_PROJECT,
region=GCP_LOCATION,
cluster_name=CLUSTER_NAME,
cluster_config=CLUSTER_CONFIG,
virtual_cluster_config=VIRTUAL_CLUSTER_CONFIG,
labels=LABELS,
)
mock_build_gcloud_command.assert_called_once_with(
command=["gcloud", "dataproc", "clusters", "gke", "create", CLUSTER_NAME],
parameters={
"region": GCP_LOCATION,
"gke-cluster": "gke_cluster_name",
"spark-engine-version": "3",
"pools": "name=dp,roles=default,machineType=e2-standard-4,min=1,max=10",
"setup-workload-identity": None,
},
)
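Because both helpers are mocked, the test could also assert that the built command is forwarded unchanged (a possible extension, not part of the diff):

mock_create_dataproc_cluster_with_gcloud.assert_called_once_with(
    cmd=mock_build_gcloud_command.return_value
)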

@mock.patch(DATAPROC_STRING.format("DataprocHook.get_cluster_client"))
def test_delete_cluster(self, mock_client):
self.hook.delete_cluster(project_id=GCP_PROJECT, region=GCP_LOCATION, cluster_name=CLUSTER_NAME)
2 changes: 1 addition & 1 deletion providers/tests/google/cloud/operators/test_dataproc.py
@@ -84,7 +84,7 @@
}
],
},
"kubernetes_software_config": {"component_version": {"SPARK": b"3"}},
"kubernetes_software_config": {"component_version": {"SPARK": "3"}},
},
"staging_bucket": "test-staging-bucket",
}
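The bytes-to-str change above matters because the hook interpolates this value into a flag string; a bytes value would render with a b prefix (illustration):

f"--spark-engine-version={b'3'}"  # -> "--spark-engine-version=b'3'"
f"--spark-engine-version={'3'}"   # -> "--spark-engine-version=3"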
