diff --git a/airflow/providers/google/cloud/hooks/bigquery.py b/airflow/providers/google/cloud/hooks/bigquery.py index c699f74756316..0dc953888724b 100644 --- a/airflow/providers/google/cloud/hooks/bigquery.py +++ b/airflow/providers/google/cloud/hooks/bigquery.py @@ -3287,8 +3287,13 @@ async def get_job_instance( self, project_id: str | None, job_id: str | None, session: ClientSession ) -> Job: """Get the specified job resource by job ID and project ID.""" - with await self.service_file_as_context() as f: - return Job(job_id=job_id, project=project_id, service_file=f, session=cast(Session, session)) + token = await self.get_token(session=session) + return Job( + job_id=job_id, + project=project_id, + token=token, + session=cast(Session, session), + ) async def get_job_status(self, job_id: str | None, project_id: str | None = None) -> dict[str, str]: async with ClientSession() as s: @@ -3532,11 +3537,11 @@ async def get_table_client( access to the specified project. :param session: aiohttp ClientSession """ - with await self.service_file_as_context() as file: - return Table_async( - dataset_name=dataset, - table_name=table_id, - project=project_id, - service_file=file, - session=cast(Session, session), - ) + token = await self.get_token(session=session) + return Table_async( + dataset_name=dataset, + table_name=table_id, + project=project_id, + token=token, + session=cast(Session, session), + ) diff --git a/airflow/providers/google/cloud/hooks/gcs.py b/airflow/providers/google/cloud/hooks/gcs.py index d904f00f4db10..f72f8c3f2287b 100644 --- a/airflow/providers/google/cloud/hooks/gcs.py +++ b/airflow/providers/google/cloud/hooks/gcs.py @@ -1398,5 +1398,8 @@ class GCSAsyncHook(GoogleBaseAsyncHook): async def get_storage_client(self, session: ClientSession) -> Storage: """Returns a Google Cloud Storage service object.""" - with await self.service_file_as_context() as file: - return Storage(service_file=file, session=cast(Session, session)) + token = await self.get_token(session=session) + return Storage( + token=token, + session=cast(Session, session), + ) diff --git a/airflow/providers/google/common/hooks/base_google.py b/airflow/providers/google/common/hooks/base_google.py index 99120820e598c..d9e4e893b1d0e 100644 --- a/airflow/providers/google/common/hooks/base_google.py +++ b/airflow/providers/google/common/hooks/base_google.py @@ -18,6 +18,7 @@ """This module contains a Google Cloud API base hook.""" from __future__ import annotations +import datetime import functools import json import logging @@ -35,6 +36,7 @@ import requests import tenacity from asgiref.sync import sync_to_async +from gcloud.aio.auth.token import Token from google.api_core.exceptions import Forbidden, ResourceExhausted, TooManyRequests from google.auth import _cloud_sdk, compute_engine # type: ignore[attr-defined] from google.auth.environment_vars import CLOUD_SDK_CONFIG_DIR, CREDENTIALS @@ -43,6 +45,7 @@ from googleapiclient import discovery from googleapiclient.errors import HttpError from googleapiclient.http import MediaIoBaseDownload, build_http, set_user_agent +from requests import Session from airflow import version from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning @@ -56,7 +59,9 @@ from airflow.utils.process_utils import patch_environ if TYPE_CHECKING: + from aiohttp import ClientSession from google.api_core.gapic_v1.client_info import ClientInfo + from google.auth.credentials import Credentials log = logging.getLogger(__name__) @@ -623,6 +628,51 @@ def test_connection(self): return status, message +class _CredentialsToken(Token): + """A token implementation which makes Google credentials objects accessible to [gcloud-aio](https://talkiq.github.io/gcloud-aio/) clients. + + This class allows us to create token instances from credentials objects and thus supports a variety of use cases for Google + credentials in Airflow (i.e. impersonation chain). By relying on a existing credentials object we leverage functionality provided by the GoogleBaseHook + for generating credentials objects. + """ + + def __init__( + self, + credentials: Credentials, + *, + project: str | None = None, + session: ClientSession | None = None, + ) -> None: + super().__init__(session=cast(Session, session)) + self.credentials = credentials + self.project = project + + @classmethod + async def from_hook( + cls, + hook: GoogleBaseHook, + *, + session: ClientSession | None = None, + ) -> _CredentialsToken: + credentials, project = hook.get_credentials_and_project_id() + return cls( + credentials=credentials, + project=project, + session=session, + ) + + async def get_project(self) -> str | None: + return self.project + + async def acquire_access_token(self, timeout: int = 10) -> None: + await sync_to_async(self.credentials.refresh)(google.auth.transport.requests.Request()) + + self.access_token = cast(str, self.credentials.token) + self.access_token_duration = 3600 + self.access_token_acquired_at = datetime.datetime.utcnow() + self.acquiring = None + + class GoogleBaseAsyncHook(BaseHook): """GoogleBaseAsyncHook inherits from BaseHook class, run on the trigger worker.""" @@ -639,6 +689,12 @@ async def get_sync_hook(self) -> Any: self._sync_hook = await sync_to_async(self.sync_hook_class)(**self._hook_kwargs) return self._sync_hook + async def get_token(self, *, session: ClientSession | None = None) -> _CredentialsToken: + """Returns a Token instance for use in [gcloud-aio](https://talkiq.github.io/gcloud-aio/) clients.""" + sync_hook = await self.get_sync_hook() + return await _CredentialsToken.from_hook(sync_hook, session=session) + async def service_file_as_context(self) -> Any: + """This is the async equivalent of the non-async GoogleBaseHook's `provide_gcp_credential_file_as_context` method.""" sync_hook = await self.get_sync_hook() return await sync_to_async(sync_hook.provide_gcp_credential_file_as_context)() diff --git a/tests/providers/google/cloud/hooks/test_bigquery.py b/tests/providers/google/cloud/hooks/test_bigquery.py index 47fc20464749e..730fa6734ad9e 100644 --- a/tests/providers/google/cloud/hooks/test_bigquery.py +++ b/tests/providers/google/cloud/hooks/test_bigquery.py @@ -22,6 +22,7 @@ from unittest import mock from unittest.mock import AsyncMock +import google.auth import pytest from gcloud.aio.bigquery import Job, Table as Table_async from google.api_core import page_iterator @@ -2143,8 +2144,12 @@ def get_credentials_and_project_id(self): class TestBigQueryAsyncHookMethods(_BigQueryBaseAsyncTestClass): @pytest.mark.db_test @pytest.mark.asyncio + @mock.patch("google.auth.default") @mock.patch("airflow.providers.google.cloud.hooks.bigquery.ClientSession") - async def test_get_job_instance(self, mock_session): + async def test_get_job_instance(self, mock_session, mock_auth_default): + mock_credentials = mock.MagicMock(spec=google.auth.compute_engine.Credentials) + mock_credentials.token = "ACCESS_TOKEN" + mock_auth_default.return_value = (mock_credentials, PROJECT_ID) hook = BigQueryAsyncHook() result = await hook.get_job_instance(project_id=PROJECT_ID, job_id=JOB_ID, session=mock_session) assert isinstance(result, Job) @@ -2315,10 +2320,13 @@ def test_convert_to_float_if_possible(self, test_input, expected): @pytest.mark.db_test @pytest.mark.asyncio + @mock.patch("google.auth.default") @mock.patch("aiohttp.client.ClientSession") - async def test_get_table_client(self, mock_session): + async def test_get_table_client(self, mock_session, mock_auth_default): """Test get_table_client async function and check whether the return value is a Table instance object""" + mock_credentials = mock.MagicMock(spec=google.auth.compute_engine.Credentials) + mock_auth_default.return_value = (mock_credentials, PROJECT_ID) hook = BigQueryTableAsyncHook() result = await hook.get_table_client( dataset=DATASET_ID, project_id=PROJECT_ID, table_id=TABLE_ID, session=mock_session diff --git a/tests/providers/google/common/hooks/test_base_google.py b/tests/providers/google/common/hooks/test_base_google.py index bd4342ec66d78..f4b71d7449ffe 100644 --- a/tests/providers/google/common/hooks/test_base_google.py +++ b/tests/providers/google/common/hooks/test_base_google.py @@ -26,6 +26,7 @@ from unittest.mock import patch import google.auth +import google.auth.compute_engine import pytest import tenacity from google.auth.environment_vars import CREDENTIALS @@ -874,3 +875,95 @@ def test_should_fallback_when_empty_string_in_env_var(self): instance = hook.GoogleBaseHook(gcp_conn_id="google_cloud_default") assert isinstance(instance.num_retries, int) assert 5 == instance.num_retries + + +class TestCredentialsToken: + @pytest.mark.asyncio + async def test_get_project(self): + mock_credentials = mock.MagicMock(spec=google.auth.compute_engine.Credentials) + token = hook._CredentialsToken(mock_credentials, project=PROJECT_ID) + assert await token.get_project() == PROJECT_ID + + @pytest.mark.asyncio + async def test_get(self): + mock_credentials = mock.MagicMock(spec=google.auth.compute_engine.Credentials) + mock_credentials.token = "ACCESS_TOKEN" + token = hook._CredentialsToken(mock_credentials, project=PROJECT_ID) + assert await token.get() == "ACCESS_TOKEN" + mock_credentials.refresh.assert_called_once() + + @pytest.mark.asyncio + @mock.patch(f"{MODULE_NAME}.get_credentials_and_project_id", return_value=("CREDENTIALS", "PROJECT_ID")) + async def test_from_hook(self, get_creds_and_project, monkeypatch): + monkeypatch.setenv( + "AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT", + "google-cloud-platform://", + ) + instance = hook.GoogleBaseHook(gcp_conn_id="google_cloud_default") + token = await hook._CredentialsToken.from_hook(instance) + assert token.credentials == "CREDENTIALS" + assert token.project == "PROJECT_ID" + + +class TestGoogleBaseAsyncHook: + @pytest.mark.asyncio + @mock.patch("google.auth.default") + async def test_get_token(self, mock_auth_default, monkeypatch) -> None: + mock_credentials = mock.MagicMock(spec=google.auth.compute_engine.Credentials) + mock_credentials.token = "ACCESS_TOKEN" + mock_auth_default.return_value = (mock_credentials, "PROJECT_ID") + monkeypatch.setenv( + "AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT", + "google-cloud-platform://?project=CONN_PROJECT_ID", + ) + + instance = hook.GoogleBaseAsyncHook(gcp_conn_id="google_cloud_default") + instance.sync_hook_class = hook.GoogleBaseHook + token = await instance.get_token() + assert await token.get_project() == "CONN_PROJECT_ID" + assert await token.get() == "ACCESS_TOKEN" + mock_credentials.refresh.assert_called_once() + + @pytest.mark.asyncio + @mock.patch("google.auth.default") + async def test_get_token_impersonation(self, mock_auth_default, monkeypatch, requests_mock) -> None: + mock_credentials = mock.MagicMock(spec=google.auth.compute_engine.Credentials) + mock_credentials.token = "ACCESS_TOKEN" + mock_auth_default.return_value = (mock_credentials, "PROJECT_ID") + monkeypatch.setenv( + "AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT", + "google-cloud-platform://?project=CONN_PROJECT_ID", + ) + requests_mock.post( + "https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/SERVICE_ACCOUNT@SA_PROJECT.iam.gserviceaccount.com:generateAccessToken", + text='{"accessToken": "IMPERSONATED_ACCESS_TOKEN", "expireTime": "2014-10-02T15:01:23Z"}', + ) + + instance = hook.GoogleBaseAsyncHook( + gcp_conn_id="google_cloud_default", + impersonation_chain="SERVICE_ACCOUNT@SA_PROJECT.iam.gserviceaccount.com", + ) + instance.sync_hook_class = hook.GoogleBaseHook + token = await instance.get_token() + assert await token.get_project() == "CONN_PROJECT_ID" + assert await token.get() == "IMPERSONATED_ACCESS_TOKEN" + + @pytest.mark.asyncio + @mock.patch("google.auth.default") + async def test_get_token_impersonation_conn(self, mock_auth_default, monkeypatch, requests_mock) -> None: + mock_credentials = mock.MagicMock(spec=google.auth.compute_engine.Credentials) + mock_auth_default.return_value = (mock_credentials, "PROJECT_ID") + monkeypatch.setenv( + "AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT", + "google-cloud-platform://?project=CONN_PROJECT_ID&impersonation_chain=SERVICE_ACCOUNT@SA_PROJECT.iam.gserviceaccount.com", + ) + requests_mock.post( + "https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/SERVICE_ACCOUNT@SA_PROJECT.iam.gserviceaccount.com:generateAccessToken", + text='{"accessToken": "IMPERSONATED_ACCESS_TOKEN", "expireTime": "2014-10-02T15:01:23Z"}', + ) + + instance = hook.GoogleBaseAsyncHook(gcp_conn_id="google_cloud_default") + instance.sync_hook_class = hook.GoogleBaseHook + token = await instance.get_token() + assert await token.get_project() == "CONN_PROJECT_ID" + assert await token.get() == "IMPERSONATED_ACCESS_TOKEN"