diff --git a/docs/apache-airflow-providers-google/connections/gcp_sql.rst b/docs/apache-airflow-providers-google/connections/gcp_sql.rst index 79fa99c0456b1..2073c462b566a 100644 --- a/docs/apache-airflow-providers-google/connections/gcp_sql.rst +++ b/docs/apache-airflow-providers-google/connections/gcp_sql.rst @@ -76,3 +76,40 @@ Extra (optional) .. code-block:: bash export AIRFLOW_CONN_GOOGLE_CLOUD_SQL_DEFAULT='gcpcloudsql://user:XXXXXXXXX@1.1.1.1:3306/mydb?database_type=mysql&project_id=example-project&location=europe-west1&instance=testinstance&use_proxy=True&sql_proxy_use_tcp=False' + +Configuring and using IAM authentication +---------------------------------------- + +.. warning:: + This functionality requires ``gcloud`` command (Google Cloud SDK) must be `installed + `_ on the Airflow worker. + +.. warning:: + IAM authentication working only for Google Service Accounts. + +Configure Service Accounts on Google Cloud IAM side +""""""""""""""""""""""""""""""""""""""""""""""""""" + +For connecting via IAM you need to use Service Account. It can be the same service account which you use for +the ``gcloud`` authentication or an another account. If you decide to use a different account then this +account should be impersonated from the account which used for ``gcloud`` authentication and granted +a ``Service Account Token Creator`` role. More information how to grant a role `here +`_. + +Also the Service Account should be configured for working with IAM. +Here are links describing what should be done before the start: `PostgreSQL +`_ and `MySQL +`_. + +Configure ``gcpcloudsql`` connection with IAM enabling +"""""""""""""""""""""""""""""""""""""""""""""""""""""" + +For using IAM you need to enable ``"use_iam": "True"`` in the ``extra`` field. And specify IAM account in this format +``USERNAME@PROJECT_ID.iam.gserviceaccount.com`` in ``login`` field and empty string in the ``password`` field. + +For example: + +.. exampleinclude:: /../../providers/tests/system/google/cloud/cloud_sql/example_cloud_sql_query_iam.py + :language: python + :start-after: [START howto_operator_cloudsql_iam_connections] + :end-before: [END howto_operator_cloudsql_iam_connections] diff --git a/providers/src/airflow/providers/google/cloud/hooks/cloud_sql.py b/providers/src/airflow/providers/google/cloud/hooks/cloud_sql.py index b1a1d1883c79e..9db7db5a1ff47 100644 --- a/providers/src/airflow/providers/google/cloud/hooks/cloud_sql.py +++ b/providers/src/airflow/providers/google/cloud/hooks/cloud_sql.py @@ -26,6 +26,7 @@ import platform import random import re +import shlex import shutil import socket import string @@ -777,6 +778,8 @@ class CloudSQLDatabaseHook(BaseHook): SQL DB. * **use_ssl** - (default False) Whether SSL should be used to connect to Cloud SQL DB. You cannot use proxy and SSL together. + * **use_iam** - (default False) Whether IAM should be used to connect to Cloud SQL DB. + With using IAM password field should be empty string. * **sql_proxy_use_tcp** - (default False) If set to true, TCP is used to connect via proxy, otherwise UNIX sockets are used. * **sql_proxy_version** - Specific version of the proxy to download (for example @@ -839,11 +842,16 @@ def __init__( self.database_type = self.extras.get("database_type") self.use_proxy = self._get_bool(self.extras.get("use_proxy", "False")) self.use_ssl = self._get_bool(self.extras.get("use_ssl", "False")) + self.use_iam = self._get_bool(self.extras.get("use_iam", "False")) self.sql_proxy_use_tcp = self._get_bool(self.extras.get("sql_proxy_use_tcp", "False")) self.sql_proxy_version = self.extras.get("sql_proxy_version") self.sql_proxy_binary_path = sql_proxy_binary_path - self.user = self.cloudsql_connection.login - self.password = self.cloudsql_connection.password + if self.use_iam: + self.user = self._get_iam_db_login() + self.password = self._generate_login_token(service_account=self.cloudsql_connection.login) + else: + self.user = self.cloudsql_connection.login + self.password = self.cloudsql_connection.password self.public_ip = self.cloudsql_connection.host self.public_port = self.cloudsql_connection.port self.ssl_cert = ssl_cert @@ -1187,3 +1195,32 @@ def free_reserved_port(self) -> None: if self.reserved_tcp_socket: self.reserved_tcp_socket.close() self.reserved_tcp_socket = None + + def _get_iam_db_login(self) -> str: + """Get an IAM login for Cloud SQL database.""" + if not self.cloudsql_connection.login: + raise AirflowException("The login parameter needs to be set in connection") + + if self.database_type == "postgres": + return self.cloudsql_connection.login.split(".gserviceaccount.com")[0] + else: + return self.cloudsql_connection.login.split("@")[0] + + def _generate_login_token(self, service_account) -> str: + """Generate an IAM login token for Cloud SQL and return the token.""" + cmd = ["gcloud", "sql", "generate-login-token", f"--impersonate-service-account={service_account}"] + self.log.info("Executing command: %s", " ".join(shlex.quote(c) for c in cmd)) + cloud_sql_hook = CloudSQLHook(api_version="v1", gcp_conn_id=self.gcp_conn_id) + + with cloud_sql_hook.provide_authorized_gcloud(): + proc = subprocess.run(cmd, capture_output=True) + + if proc.returncode != 0: + stderr_last_20_lines = "\n".join(proc.stderr.decode().strip().splitlines()[-20:]) + raise AirflowException( + f"Process exited with non-zero exit code. Exit code: {proc.returncode}. Error Details: " + f"{stderr_last_20_lines}" + ) + + auth_token = proc.stdout.decode().strip() + return auth_token diff --git a/providers/tests/system/google/cloud/cloud_sql/example_cloud_sql_query_iam.py b/providers/tests/system/google/cloud/cloud_sql/example_cloud_sql_query_iam.py new file mode 100644 index 0000000000000..4330137d977a3 --- /dev/null +++ b/providers/tests/system/google/cloud/cloud_sql/example_cloud_sql_query_iam.py @@ -0,0 +1,442 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Example Airflow DAG that performs query in a Cloud SQL instance with IAM service account authentication. +""" + +from __future__ import annotations + +import logging +import os +import random +import string +from copy import deepcopy +from datetime import datetime +from pathlib import Path +from typing import Any, Iterable + +from googleapiclient import discovery + +from airflow import settings +from airflow.decorators import task +from airflow.models.connection import Connection +from airflow.models.dag import DAG +from airflow.providers.google.cloud.hooks.cloud_sql import CloudSQLHook +from airflow.providers.google.cloud.operators.cloud_sql import ( + CloudSQLCreateInstanceDatabaseOperator, + CloudSQLCreateInstanceOperator, + CloudSQLDeleteInstanceOperator, + CloudSQLExecuteQueryOperator, +) +from airflow.settings import Session +from airflow.utils.trigger_rule import TriggerRule + +from providers.tests.system.google import DEFAULT_GCP_SYSTEM_TEST_PROJECT_ID + +ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID", "default") +PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT") or DEFAULT_GCP_SYSTEM_TEST_PROJECT_ID +DAG_ID = "cloudsql_query_iam" +REGION = "us-central1" +HOME_DIR = Path.home() + +IS_COMPOSER = bool(os.environ.get("COMPOSER_ENVIRONMENT", "")) + +CLOUD_SQL_INSTANCE_NAME_TEMPLATE = f"{ENV_ID}-{DAG_ID}".replace("_", "-") +CLOUD_SQL_INSTANCE_CREATE_BODY_TEMPLATE: dict[str, Any] = { + "name": CLOUD_SQL_INSTANCE_NAME_TEMPLATE, + "settings": { + "tier": "db-custom-1-3840", + "dataDiskSizeGb": 30, + "pricingPlan": "PER_USE", + "ipConfiguration": {}, + }, + # For using a different database version please check the link below. + # https://cloud.google.com/sql/docs/mysql/admin-api/rest/v1/SqlDatabaseVersion + "databaseVersion": "1.2.3", + "region": REGION, + "ipConfiguration": { + "ipv4Enabled": True, + "requireSsl": True, + "authorizedNetworks": [ + {"value": "0.0.0.0/0"}, + ], + }, +} + + +def ip_configuration() -> dict[str, Any]: + """Generates an ip configuration for a CloudSQL instance creation body""" + if IS_COMPOSER: + # Use connection to Cloud SQL instance via Private IP within the Cloud Composer's network. + return { + "ipv4Enabled": True, + "requireSsl": False, + "sslMode": "ENCRYPTED_ONLY", + "enablePrivatePathForGoogleCloudServices": True, + "privateNetwork": f"projects/{PROJECT_ID}/global/networks/default", + } + else: + # Use connection to Cloud SQL instance via Public IP from anywhere (mask 0.0.0.0/0). + # Consider specifying your network mask + # for allowing requests only from the trusted sources, not from anywhere. + return { + "ipv4Enabled": True, + "requireSsl": True, + "sslMode": "TRUSTED_CLIENT_CERTIFICATE_REQUIRED", + "authorizedNetworks": [ + {"value": "0.0.0.0/0"}, + ], + } + + +def cloud_sql_instance_create_body(database_provider: dict[str, Any]) -> dict[str, Any]: + """Generates a CloudSQL instance creation body""" + create_body: dict[str, Any] = deepcopy(CLOUD_SQL_INSTANCE_CREATE_BODY_TEMPLATE) + create_body["name"] = database_provider["cloud_sql_instance_name"] + create_body["databaseVersion"] = database_provider["database_version"] + create_body["settings"]["ipConfiguration"] = ip_configuration() + create_body["settings"]["databaseFlags"] = [ + {"name": database_provider["database_iam_flag_name"], "value": "on"} + ] + return create_body + + +CLOUD_SQL_DATABASE_NAME = "test_db" +CLOUD_SQL_USER = "test_user" +CLOUD_IAM_SA = os.environ.get("CLOUD_IAM_SA", "test_iam_sa") +CLOUD_SQL_IP_ADDRESS = "127.0.0.1" +CLOUD_SQL_PUBLIC_PORT = 5432 + +DB_PROVIDERS: Iterable[dict[str, str]] = ( + { + "database_type": "postgres", + "port": "5432", + "database_version": "POSTGRES_15", + "cloud_sql_instance_name": f"{CLOUD_SQL_INSTANCE_NAME_TEMPLATE}-postgres", + "database_iam_flag_name": "cloudsql.iam_authentication", + "cloud_sql_iam_sa": CLOUD_IAM_SA.split(".gserviceaccount.com")[0], + }, + { + "database_type": "mysql", + "port": "3306", + "database_version": "MYSQL_8_0", + "cloud_sql_instance_name": f"{CLOUD_SQL_INSTANCE_NAME_TEMPLATE}-mysql", + "database_iam_flag_name": "cloudsql_iam_authentication", + "cloud_sql_iam_sa": CLOUD_IAM_SA, + }, +) + + +def cloud_sql_database_create_body(instance: str) -> dict[str, Any]: + """Generates a CloudSQL database creation body""" + return { + "instance": instance, + "name": CLOUD_SQL_DATABASE_NAME, + "project": PROJECT_ID, + } + + +CLOUD_SQL_INSTANCE_NAME = "" +DATABASE_TYPE = "" # "postgres|mysql|mssql" + +# [START howto_operator_cloudsql_iam_connections] +CONNECTION_WITH_IAM_KWARGS = { + "conn_type": "gcpcloudsql", + "login": CLOUD_IAM_SA, + "password": "", + "host": CLOUD_SQL_IP_ADDRESS, + "port": CLOUD_SQL_PUBLIC_PORT, + "schema": CLOUD_SQL_DATABASE_NAME, + "extra": { + "database_type": DATABASE_TYPE, + "project_id": PROJECT_ID, + "location": REGION, + "instance": CLOUD_SQL_INSTANCE_NAME, + "use_proxy": "False", + "use_ssl": "True", + "use_iam": "True", + }, +} +# [END howto_operator_cloudsql_iam_connections] + +CONNECTION_PUBLIC_TCP_SSL_ID = f"{DAG_ID}_{ENV_ID}_tcp_ssl" + +PG_SQL = ["SELECT * FROM pg_catalog.pg_tables"] + +MYSQL_SQL = ["SHOW TABLES"] + +SSL_PATH = f"/{DAG_ID}/{ENV_ID}" +SSL_LOCAL_PATH_PREFIX = "/tmp" +SSL_COMPOSER_PATH_PREFIX = "/home/airflow/gcs/data" + +# The connections below are created using one of the standard approaches - via environment +# variables named AIRFLOW_CONN_* . The connections can also be created in the database +# of AIRFLOW (using command line or UI). + +postgres_kwargs = { + "user": "user", + "password": "password", + "public_ip": "public_ip", + "public_port": "public_port", + "database": "database", + "project_id": "project_id", + "location": "location", + "instance": "instance", + "client_cert_file": "client_cert_file", + "client_key_file": "client_key_file", + "server_ca_file": "server_ca_file", +} + +# Postgres: connect directly via TCP (SSL) +os.environ["AIRFLOW_CONN_PUBLIC_POSTGRES_TCP_SSL"] = ( + "gcpcloudsql://{user}:{password}@{public_ip}:{public_port}/{database}?" + "database_type=postgres&" + "project_id={project_id}&" + "location={location}&" + "instance={instance}&" + "use_proxy=False&" + "use_ssl=True&" + "sslcert={client_cert_file}&" + "sslkey={client_key_file}&" + "sslrootcert={server_ca_file}".format(**postgres_kwargs) +) + +mysql_kwargs = { + "user": "user", + "password": "password", + "public_ip": "public_ip", + "public_port": "public_port", + "database": "database", + "project_id": "project_id", + "location": "location", + "instance": "instance", + "client_cert_file": "client_cert_file", + "client_key_file": "client_key_file", + "server_ca_file": "server_ca_file", +} + +# MySQL: connect directly via TCP (SSL) +os.environ["AIRFLOW_CONN_PUBLIC_MYSQL_TCP_SSL"] = ( + "gcpcloudsql://{user}:{password}@{public_ip}:{public_port}/{database}?" + "database_type=mysql&" + "project_id={project_id}&" + "location={location}&" + "instance={instance}&" + "use_proxy=False&" + "use_ssl=True&" + "sslcert={client_cert_file}&" + "sslkey={client_key_file}&" + "sslrootcert={server_ca_file}".format(**mysql_kwargs) +) + + +log = logging.getLogger(__name__) + +with DAG( + dag_id=DAG_ID, + start_date=datetime(2021, 1, 1), + schedule="@once", + catchup=False, + tags=["example", "cloudsql", "postgres"], +) as dag: + for db_provider in DB_PROVIDERS: + database_type: str = db_provider["database_type"] + cloud_sql_instance_name: str = db_provider["cloud_sql_instance_name"] + cloud_sql_iam_sa: str = db_provider["cloud_sql_iam_sa"] + + create_cloud_sql_instance = CloudSQLCreateInstanceOperator( + task_id=f"create_cloud_sql_instance_{database_type}", + project_id=PROJECT_ID, + instance=cloud_sql_instance_name, + body=cloud_sql_instance_create_body(database_provider=db_provider), + ) + + create_database = CloudSQLCreateInstanceDatabaseOperator( + task_id=f"create_database_{database_type}", + body=cloud_sql_database_create_body(instance=cloud_sql_instance_name), + instance=cloud_sql_instance_name, + ) + + @task(task_id=f"create_user_{database_type}") + def create_user(instance: str, service_account: str) -> None: + with discovery.build("sqladmin", "v1beta4") as service: + request = service.users().insert( + project=PROJECT_ID, + instance=instance, + body={ + "name": service_account, + "type": "CLOUD_IAM_SERVICE_ACCOUNT", + }, + ) + request.execute() + return None + + create_user_task = create_user(instance=cloud_sql_instance_name, service_account=cloud_sql_iam_sa) + + @task(task_id=f"get_ip_address_{database_type}") + def get_ip_address(instance: str) -> str | None: + """Returns a Cloud SQL instance IP address. + + If the test is running in Cloud Composer, the Private IP address is used, otherwise Public IP.""" + with discovery.build("sqladmin", "v1beta4") as service: + request = service.connect().get( + project=PROJECT_ID, + instance=instance, + fields="ipAddresses", + ) + response = request.execute() + for ip_item in response.get("ipAddresses", []): + if IS_COMPOSER: + if ip_item["type"] == "PRIVATE": + return ip_item["ipAddress"] + else: + if ip_item["type"] == "PRIMARY": + return ip_item["ipAddress"] + return None + + get_ip_address_task = get_ip_address(instance=cloud_sql_instance_name) + + conn_id = f"{CONNECTION_PUBLIC_TCP_SSL_ID}_{database_type}" + + @task(task_id=f"create_connection_{database_type}") + def create_connection( + connection_id: str, instance: str, db_type: str, ip_address: str, port: str + ) -> str | None: + session = settings.Session() + log.info("Removing connection %s if it exists", connection_id) + query = session.query(Connection).filter(Connection.conn_id == connection_id) + query.delete() + + connection: dict[str, Any] = deepcopy(CONNECTION_WITH_IAM_KWARGS) + connection["extra"]["instance"] = instance + connection["host"] = ip_address + connection["extra"]["database_type"] = db_type + connection["port"] = port + conn = Connection(conn_id=connection_id, **connection) + session.add(conn) + session.commit() + log.info("Connection created: '%s'", connection_id) + return connection_id + + create_connection_task = create_connection( + connection_id=conn_id, + instance=cloud_sql_instance_name, + db_type=database_type, + ip_address=get_ip_address_task, + port=db_provider["port"], + ) + + @task(task_id=f"create_ssl_certificates_{database_type}") + def create_ssl_certificate(instance: str, connection_id: str) -> dict[str, Any]: + hook = CloudSQLHook(api_version="v1", gcp_conn_id=connection_id) + certificate_name = f"test_cert_{''.join(random.choice(string.ascii_letters) for _ in range(8))}" + response = hook.create_ssl_certificate( + instance=instance, + body={"common_name": certificate_name}, + project_id=PROJECT_ID, + ) + return response + + create_ssl_certificate_task = create_ssl_certificate( + instance=cloud_sql_instance_name, connection_id=create_connection_task + ) + + @task(task_id=f"save_ssl_cert_locally_{database_type}") + def save_ssl_cert_locally(ssl_cert: dict[str, Any], db_type: str) -> dict[str, str]: + folder = SSL_COMPOSER_PATH_PREFIX if IS_COMPOSER else SSL_LOCAL_PATH_PREFIX + folder += f"/certs/{db_type}/{ssl_cert['operation']['name']}" + os.makedirs(folder, exist_ok=True) + _ssl_root_cert_path = f"{folder}/sslrootcert.pem" + _ssl_cert_path = f"{folder}/sslcert.pem" + _ssl_key_path = f"{folder}/sslkey.pem" + with open(_ssl_root_cert_path, "w") as ssl_root_cert_file: + ssl_root_cert_file.write(ssl_cert["serverCaCert"]["cert"]) + with open(_ssl_cert_path, "w") as ssl_cert_file: + ssl_cert_file.write(ssl_cert["clientCert"]["certInfo"]["cert"]) + with open(_ssl_key_path, "w") as ssl_key_file: + ssl_key_file.write(ssl_cert["clientCert"]["certPrivateKey"]) + return { + "sslrootcert": _ssl_root_cert_path, + "sslcert": _ssl_cert_path, + "sslkey": _ssl_key_path, + } + + save_ssl_cert_locally_task = save_ssl_cert_locally( + ssl_cert=create_ssl_certificate_task, db_type=database_type + ) + + task_id = f"example_cloud_sql_query_ssl_{database_type}" + ssl_server_cert_path = ( + f"{{{{ task_instance.xcom_pull('save_ssl_cert_locally_{database_type}')['sslrootcert'] }}}}" + ) + ssl_cert_path = ( + f"{{{{ task_instance.xcom_pull('save_ssl_cert_locally_{database_type}')['sslcert'] }}}}" + ) + ssl_key_path = f"{{{{ task_instance.xcom_pull('save_ssl_cert_locally_{database_type}')['sslkey'] }}}}" + + query_task = CloudSQLExecuteQueryOperator( + gcp_cloudsql_conn_id=conn_id, + task_id=task_id, + sql=PG_SQL if database_type == "postgres" else MYSQL_SQL, + ssl_client_cert=ssl_cert_path, + ssl_server_cert=ssl_server_cert_path, + ssl_client_key=ssl_key_path, + ) + + delete_instance = CloudSQLDeleteInstanceOperator( + task_id=f"delete_cloud_sql_instance_{database_type}", + project_id=PROJECT_ID, + instance=cloud_sql_instance_name, + trigger_rule=TriggerRule.ALL_DONE, + ) + + @task(task_id=f"delete_connection_{database_type}") + def delete_connection(connection_id: str) -> None: + session = Session() + log.info("Removing connection %s", connection_id) + query = session.query(Connection).filter(Connection.conn_id == connection_id) + query.delete() + session.commit() + + delete_connection_task = delete_connection(connection_id=conn_id) + + ( + # TEST SETUP + create_cloud_sql_instance + >> [create_database, create_user_task, get_ip_address_task] + >> create_connection_task + >> create_ssl_certificate_task + >> save_ssl_cert_locally_task + # TEST BODY + >> query_task + # TEST TEARDOWN + >> [delete_instance, delete_connection_task] + ) + + # ### Everything below this line is not part of example ### + # ### Just for system tests purpose ### + from tests_common.test_utils.watcher import watcher + + # This test needs watcher in order to properly mark success/failure + # when "tearDown" task with trigger rule is part of the DAG + list(dag.tasks) >> watcher() + +from tests_common.test_utils.system_tests import get_test_run # noqa: E402 + +# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest) +test_run = get_test_run(dag)