Apply PROVIDE_PROJECT_ID mypy workaround across Google provider (apache#39129)

There is a simple workaround, implemented several years ago, for the
Google provider: giving `project_id` a default value of
PROVIDE_PROJECT_ID satisfies mypy checks that expect project_id to be
set. The way `fallback_to_default_project_id` works, project_id is
actually set across all the providers, even if its default value is
technically None.

This is a similar typing workaround to the one we use for NEW_SESSION
in the core of Airflow.

The workaround has not been applied consistently across all the
Google provider code, and it occasionally causes MyPy to complain
when a newer version of a Google library introduces stricter type
checking and expects project_id to be set.

This PR applies the workaround across all the Google provider
code.

This is, generally speaking, a no-op change. Nothing changes, except
that MyPy is now aware that project_id is actually going to be set
even if it technically defaults to None.
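
For illustration, a minimal sketch of the pattern (simplified and partly
hypothetical: the real PROVIDE_PROJECT_ID constant and the
`GoogleBaseHook.fallback_to_default_project_id` decorator live in
airflow/providers/google/common/hooks/base_google.py and differ in detail;
the example hook class and default project name below are made up):

```python
from __future__ import annotations

from functools import wraps
from typing import Any, Callable, TypeVar, cast

RT = TypeVar("RT")

# Sentinel: None at runtime, but typed as `str`, so mypy treats
# `project_id: str = PROVIDE_PROJECT_ID` as always set (the same trick
# as the NEW_SESSION sentinel in Airflow core).
PROVIDE_PROJECT_ID: str = cast(str, None)


def fallback_to_default_project_id(func: Callable[..., RT]) -> Callable[..., RT]:
    # Hypothetical, simplified stand-in for
    # GoogleBaseHook.fallback_to_default_project_id: if the caller did not
    # pass a real project_id, substitute the hook's default one.
    @wraps(func)
    def wrapper(self: Any, *args: Any, **kwargs: Any) -> RT:
        if kwargs.get("project_id") is None:
            kwargs["project_id"] = self.default_project_id
        return func(self, *args, **kwargs)

    return wrapper


class ExampleHook:  # hypothetical hook, for illustration only
    default_project_id = "example-project"

    @fallback_to_default_project_id
    def get_dataset(self, dataset_id: str, project_id: str = PROVIDE_PROJECT_ID) -> dict:
        # By the time the body runs, project_id is a real string, which is
        # exactly what the `str` annotation promises to mypy.
        return {"projectId": project_id, "datasetId": dataset_id}


if __name__ == "__main__":
    print(ExampleHook().get_dataset(dataset_id="my_dataset"))
    # -> {'projectId': 'example-project', 'datasetId': 'my_dataset'}
```

Callers that pass an explicit project_id keep full control; callers that
omit it get the default substituted, without ever violating the `str`
annotation.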
potiuk authored Apr 19, 2024
1 parent 2674a69 commit 90acbfb
Showing 64 changed files with 439 additions and 362 deletions.
2 changes: 1 addition & 1 deletion airflow/providers/google/cloud/hooks/automl.py
@@ -529,7 +529,7 @@ def list_table_specs(
self,
dataset_id: str,
location: str,
project_id: str | None = None,
project_id: str = PROVIDE_PROJECT_ID,
filter_: str | None = None,
page_size: int | None = None,
retry: Retry | _MethodDefault = DEFAULT,
65 changes: 36 additions & 29 deletions airflow/providers/google/cloud/hooks/bigquery.py
@@ -59,7 +59,12 @@
from airflow.providers.google.cloud.utils.bigquery import bq_cast
from airflow.providers.google.cloud.utils.credentials_provider import _get_scopes
from airflow.providers.google.common.consts import CLIENT_INFO
from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook, get_field
from airflow.providers.google.common.hooks.base_google import (
PROVIDE_PROJECT_ID,
GoogleBaseAsyncHook,
GoogleBaseHook,
get_field,
)

try:
from airflow.utils.hashlib_wrapper import md5
@@ -198,7 +203,7 @@ def get_service(self) -> Resource:
http_authorized = self._authorize()
return build("bigquery", "v2", http=http_authorized, cache_discovery=False)

def get_client(self, project_id: str | None = None, location: str | None = None) -> Client:
def get_client(self, project_id: str = PROVIDE_PROJECT_ID, location: str | None = None) -> Client:
"""Get an authenticated BigQuery Client.
:param project_id: Project ID for the project which the client acts on behalf of.
@@ -250,7 +255,7 @@ def get_records(self, sql, parameters=None):
@staticmethod
def _resolve_table_reference(
table_resource: dict[str, Any],
project_id: str | None = None,
project_id: str = PROVIDE_PROJECT_ID,
dataset_id: str | None = None,
table_id: str | None = None,
) -> dict[str, Any]:
@@ -360,7 +365,7 @@ def table_partition_exists(
@GoogleBaseHook.fallback_to_default_project_id
def create_empty_table(
self,
project_id: str | None = None,
project_id: str = PROVIDE_PROJECT_ID,
dataset_id: str | None = None,
table_id: str | None = None,
table_resource: dict[str, Any] | None = None,
@@ -474,7 +479,7 @@ def create_empty_table(
def create_empty_dataset(
self,
dataset_id: str | None = None,
project_id: str | None = None,
project_id: str = PROVIDE_PROJECT_ID,
location: str | None = None,
dataset_reference: dict[str, Any] | None = None,
exists_ok: bool = True,
@@ -536,7 +541,7 @@ def create_empty_dataset(
def get_dataset_tables(
self,
dataset_id: str,
project_id: str | None = None,
project_id: str = PROVIDE_PROJECT_ID,
max_results: int | None = None,
retry: Retry = DEFAULT_RETRY,
) -> list[dict[str, Any]]:
@@ -565,7 +570,7 @@ def get_dataset_tables(
def delete_dataset(
self,
dataset_id: str,
project_id: str | None = None,
project_id: str = PROVIDE_PROJECT_ID,
delete_contents: bool = False,
retry: Retry = DEFAULT_RETRY,
) -> None:
@@ -614,7 +619,7 @@ def create_external_table(
description: str | None = None,
encryption_configuration: dict | None = None,
location: str | None = None,
project_id: str | None = None,
project_id: str = PROVIDE_PROJECT_ID,
) -> Table:
"""Create an external table in the dataset with data from Google Cloud Storage.
@@ -750,7 +755,7 @@ def update_table(
fields: list[str] | None = None,
dataset_id: str | None = None,
table_id: str | None = None,
project_id: str | None = None,
project_id: str = PROVIDE_PROJECT_ID,
) -> dict[str, Any]:
"""Change some fields of a table.
@@ -796,7 +801,7 @@ def patch_table(
self,
dataset_id: str,
table_id: str,
project_id: str | None = None,
project_id: str = PROVIDE_PROJECT_ID,
description: str | None = None,
expiration_time: int | None = None,
external_data_configuration: dict | None = None,
@@ -953,7 +958,7 @@ def update_dataset(
fields: Sequence[str],
dataset_resource: dict[str, Any],
dataset_id: str | None = None,
project_id: str | None = None,
project_id: str = PROVIDE_PROJECT_ID,
retry: Retry = DEFAULT_RETRY,
) -> Dataset:
"""Change some fields of a dataset.
@@ -999,7 +1004,9 @@ def update_dataset(
),
category=AirflowProviderDeprecationWarning,
)
def patch_dataset(self, dataset_id: str, dataset_resource: dict, project_id: str | None = None) -> dict:
def patch_dataset(
self, dataset_id: str, dataset_resource: dict, project_id: str = PROVIDE_PROJECT_ID
) -> dict:
"""Patches information in an existing dataset.
It only replaces fields that are provided in the submitted dataset resource.
@@ -1047,7 +1054,7 @@ def patch_dataset(self, dataset_id: str, dataset_resource: dict, project_id: str
def get_dataset_tables_list(
self,
dataset_id: str,
project_id: str | None = None,
project_id: str = PROVIDE_PROJECT_ID,
table_prefix: str | None = None,
max_results: int | None = None,
) -> list[dict[str, Any]]:
@@ -1084,7 +1091,7 @@ def get_dataset_tables_list(
@GoogleBaseHook.fallback_to_default_project_id
def get_datasets_list(
self,
project_id: str | None = None,
project_id: str = PROVIDE_PROJECT_ID,
include_all: bool = False,
filter_: str | None = None,
max_results: int | None = None,
@@ -1134,7 +1141,7 @@ def get_datasets_list(
return datasets_list

@GoogleBaseHook.fallback_to_default_project_id
def get_dataset(self, dataset_id: str, project_id: str | None = None) -> Dataset:
def get_dataset(self, dataset_id: str, project_id: str = PROVIDE_PROJECT_ID) -> Dataset:
"""Fetch the dataset referenced by *dataset_id*.
:param dataset_id: The BigQuery Dataset ID
@@ -1158,7 +1165,7 @@ def run_grant_dataset_view_access(
view_dataset: str,
view_table: str,
view_project: str | None = None,
project_id: str | None = None,
project_id: str = PROVIDE_PROJECT_ID,
) -> dict[str, Any]:
"""Grant authorized view access of a dataset to a view table.
@@ -1210,7 +1217,7 @@ def run_grant_dataset_view_access(

@GoogleBaseHook.fallback_to_default_project_id
def run_table_upsert(
self, dataset_id: str, table_resource: dict[str, Any], project_id: str | None = None
self, dataset_id: str, table_resource: dict[str, Any], project_id: str = PROVIDE_PROJECT_ID
) -> dict[str, Any]:
"""Update a table if it exists, otherwise create a new one.
@@ -1267,7 +1274,7 @@ def delete_table(
self,
table_id: str,
not_found_ok: bool = True,
project_id: str | None = None,
project_id: str = PROVIDE_PROJECT_ID,
) -> None:
"""Delete an existing table from the dataset.
@@ -1334,7 +1341,7 @@ def list_rows(
selected_fields: list[str] | str | None = None,
page_token: str | None = None,
start_index: int | None = None,
project_id: str | None = None,
project_id: str = PROVIDE_PROJECT_ID,
location: str | None = None,
retry: Retry = DEFAULT_RETRY,
return_iterator: bool = False,
@@ -1387,7 +1394,7 @@ def list_rows(
return list(iterator)

@GoogleBaseHook.fallback_to_default_project_id
def get_schema(self, dataset_id: str, table_id: str, project_id: str | None = None) -> dict:
def get_schema(self, dataset_id: str, table_id: str, project_id: str = PROVIDE_PROJECT_ID) -> dict:
"""Get the schema for a given dataset and table.
.. seealso:: https://cloud.google.com/bigquery/docs/reference/v2/tables#resource
@@ -1409,7 +1416,7 @@ def update_table_schema(
include_policy_tags: bool,
dataset_id: str,
table_id: str,
project_id: str | None = None,
project_id: str = PROVIDE_PROJECT_ID,
) -> dict[str, Any]:
"""Update fields within a schema for a given dataset and table.
@@ -1502,7 +1509,7 @@ def _remove_policy_tags(schema: list[dict[str, Any]]):
def poll_job_complete(
self,
job_id: str,
project_id: str | None = None,
project_id: str = PROVIDE_PROJECT_ID,
location: str | None = None,
retry: Retry = DEFAULT_RETRY,
) -> bool:
@@ -1532,7 +1539,7 @@ def cancel_query(self) -> None:
def cancel_job(
self,
job_id: str,
project_id: str | None = None,
project_id: str = PROVIDE_PROJECT_ID,
location: str | None = None,
) -> None:
"""Cancel a job and wait for cancellation to complete.
@@ -1576,7 +1583,7 @@ def cancel_job(
def get_job(
self,
job_id: str,
project_id: str | None = None,
project_id: str = PROVIDE_PROJECT_ID,
location: str | None = None,
) -> CopyJob | QueryJob | LoadJob | ExtractJob | UnknownJob:
"""Retrieve a BigQuery job.
@@ -1607,7 +1614,7 @@ def insert_job(
self,
configuration: dict,
job_id: str | None = None,
project_id: str | None = None,
project_id: str = PROVIDE_PROJECT_ID,
location: str | None = None,
nowait: bool = False,
retry: Retry = DEFAULT_RETRY,
@@ -3304,7 +3311,7 @@ async def get_job_instance(
)

async def _get_job(
self, job_id: str | None, project_id: str | None = None, location: str | None = None
self, job_id: str | None, project_id: str = PROVIDE_PROJECT_ID, location: str | None = None
) -> CopyJob | QueryJob | LoadJob | ExtractJob | UnknownJob:
"""
Get BigQuery job by its ID, project ID and location.
@@ -3347,7 +3354,7 @@ def _get_job_sync(self, job_id, project_id, location):
return hook.get_job(job_id=job_id, project_id=project_id, location=location)

async def get_job_status(
self, job_id: str | None, project_id: str | None = None, location: str | None = None
self, job_id: str | None, project_id: str = PROVIDE_PROJECT_ID, location: str | None = None
) -> dict[str, str]:
job = await self._get_job(job_id=job_id, project_id=project_id, location=location)
if job.state == "DONE":
@@ -3359,7 +3366,7 @@ async def get_job_status(
async def get_job_output(
self,
job_id: str | None,
project_id: str | None = None,
project_id: str = PROVIDE_PROJECT_ID,
) -> dict[str, Any]:
"""Get the BigQuery job output for a given job ID asynchronously."""
async with ClientSession() as session:
@@ -3372,7 +3379,7 @@ async def create_job_for_partition_get(
self,
dataset_id: str | None,
table_id: str | None = None,
project_id: str | None = None,
project_id: str = PROVIDE_PROJECT_ID,
):
"""Create a new job and get the job_id using gcloud-aio."""
async with ClientSession() as session:
9 changes: 7 additions & 2 deletions airflow/providers/google/cloud/hooks/cloud_sql.py
@@ -53,7 +53,12 @@
from airflow.providers.google.cloud.hooks.secret_manager import (
GoogleCloudSecretManagerHook,
)
from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook, get_field
from airflow.providers.google.common.hooks.base_google import (
PROVIDE_PROJECT_ID,
GoogleBaseAsyncHook,
GoogleBaseHook,
get_field,
)
from airflow.providers.mysql.hooks.mysql import MySqlHook
from airflow.providers.postgres.hooks.postgres import PostgresHook
from airflow.utils.log.logging_mixin import LoggingMixin
@@ -510,7 +515,7 @@ def __init__(
path_prefix: str,
instance_specification: str,
gcp_conn_id: str = "google_cloud_default",
project_id: str | None = None,
project_id: str = PROVIDE_PROJECT_ID,
sql_proxy_version: str | None = None,
sql_proxy_binary_path: str | None = None,
) -> None:
Original file line number Diff line number Diff line change
@@ -46,7 +46,11 @@

from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.providers.google.common.consts import CLIENT_INFO
from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook
from airflow.providers.google.common.hooks.base_google import (
PROVIDE_PROJECT_ID,
GoogleBaseAsyncHook,
GoogleBaseHook,
)

if TYPE_CHECKING:
from google.cloud.storage_transfer_v1.services.storage_transfer_service.pagers import (
@@ -504,7 +508,7 @@ def operations_contain_expected_statuses(
class CloudDataTransferServiceAsyncHook(GoogleBaseAsyncHook):
"""Asynchronous hook for Google Storage Transfer Service."""

def __init__(self, project_id: str | None = None, **kwargs: Any) -> None:
def __init__(self, project_id: str = PROVIDE_PROJECT_ID, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.project_id = project_id
self._client: StorageTransferServiceAsyncClient | None = None
3 changes: 2 additions & 1 deletion airflow/providers/google/cloud/hooks/compute_ssh.py
@@ -29,6 +29,7 @@
from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.hooks.compute import ComputeEngineHook
from airflow.providers.google.cloud.hooks.os_login import OSLoginHook
from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID
from airflow.providers.ssh.hooks.ssh import SSHHook
from airflow.utils.types import NOTSET, ArgNotSet

@@ -109,7 +110,7 @@ def __init__(
instance_name: str | None = None,
zone: str | None = None,
user: str | None = "root",
project_id: str | None = None,
project_id: str = PROVIDE_PROJECT_ID,
hostname: str | None = None,
use_internal_ip: bool = False,
use_iap_tunnel: bool = False,
8 changes: 6 additions & 2 deletions airflow/providers/google/cloud/hooks/dataplex.py
@@ -36,7 +36,11 @@

from airflow.exceptions import AirflowException
from airflow.providers.google.common.consts import CLIENT_INFO
from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook
from airflow.providers.google.common.hooks.base_google import (
PROVIDE_PROJECT_ID,
GoogleBaseAsyncHook,
GoogleBaseHook,
)

if TYPE_CHECKING:
from google.api_core.operation import Operation
@@ -665,7 +669,7 @@ def wait_for_data_scan_job(
self,
data_scan_id: str,
job_id: str | None = None,
project_id: str | None = None,
project_id: str = PROVIDE_PROJECT_ID,
region: str | None = None,
wait_time: int = 10,
result_timeout: float | None = None,
