From cace85aecd4d9cb29849fb1823032b5b855f1e4e Mon Sep 17 00:00:00 2001 From: Pankaj Date: Fri, 8 Dec 2023 13:32:55 +0530 Subject: [PATCH] Fix mypy --- astronomer/providers/google/cloud/hooks/dataproc.py | 11 ++++++++--- .../providers/snowflake/hooks/snowflake_sql_api.py | 10 +++++----- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/astronomer/providers/google/cloud/hooks/dataproc.py b/astronomer/providers/google/cloud/hooks/dataproc.py index 258813566..5cce09f84 100644 --- a/astronomer/providers/google/cloud/hooks/dataproc.py +++ b/astronomer/providers/google/cloud/hooks/dataproc.py @@ -5,7 +5,7 @@ from airflow.providers.google.common.hooks.base_google import GoogleBaseHook from google.api_core import gapic_v1 from google.api_core.client_options import ClientOptions -from google.api_core.retry import Retry +from google.api_core import retry_async as retries from google.cloud.dataproc_v1 import ( ClusterControllerAsyncClient, Job, @@ -13,6 +13,11 @@ ) from google.cloud.dataproc_v1.types import clusters +try: + OptionalRetry = Union[retries.AsyncRetry, gapic_v1.method._MethodDefault] +except AttributeError: + OptionalRetry = Union[retries.AsyncRetry, object] + JobType = Union[Job, Any] @@ -68,7 +73,7 @@ async def get_cluster( region: str, cluster_name: str, project_id: str, - retry: Union[Retry, gapic_v1.method._MethodDefault] = gapic_v1.method.DEFAULT, + retry: OptionalRetry = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = (), ) -> clusters.Cluster: """ @@ -98,7 +103,7 @@ async def get_job( timeout: float = 5, region: Optional[str] = None, location: Optional[str] = None, - retry: Union[Retry, gapic_v1.method._MethodDefault] = gapic_v1.method.DEFAULT, + retry: OptionalRetry = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = (), ) -> JobType: """ diff --git a/astronomer/providers/snowflake/hooks/snowflake_sql_api.py b/astronomer/providers/snowflake/hooks/snowflake_sql_api.py index d8ebe3c28..535b7a295 100644 --- a/astronomer/providers/snowflake/hooks/snowflake_sql_api.py +++ b/astronomer/providers/snowflake/hooks/snowflake_sql_api.py @@ -3,7 +3,7 @@ import uuid from datetime import timedelta from pathlib import Path -from typing import TYPE_CHECKING, Any +from typing import Any import aiohttp import requests @@ -140,8 +140,8 @@ def execute_query( try: response.raise_for_status() except requests.exceptions.HTTPError as e: # pragma: no cover - if TYPE_CHECKING: - assert e.response is not None + if e.response is None: + raise raise AirflowException( f"Response: {e.response.content.decode()}, " f"Status Code: {e.response.status_code}" ) # pragma: no cover @@ -205,8 +205,8 @@ def check_query_output(self, query_ids: list[str]) -> None: response.raise_for_status() self.log.info(response.json()) except requests.exceptions.HTTPError as e: - if TYPE_CHECKING: - assert e.response is not None + if e.response is None: + raise raise AirflowException( f"Response: {e.response.content.decode()}, Status Code: {e.response.status_code}" )