diff --git a/astronomer/providers/google/cloud/sensors/bigquery.py b/astronomer/providers/google/cloud/sensors/bigquery.py
index 73f81e595..e4c27eb35 100644
--- a/astronomer/providers/google/cloud/sensors/bigquery.py
+++ b/astronomer/providers/google/cloud/sensors/bigquery.py
@@ -2,43 +2,16 @@
 from __future__ import annotations
 
 import warnings
-from datetime import timedelta
 from typing import Any
 
 from airflow.providers.google.cloud.sensors.bigquery import BigQueryTableExistenceSensor
 
-from astronomer.providers.google.cloud.triggers.bigquery import (
-    BigQueryTableExistenceTrigger,
-)
-from astronomer.providers.utils.sensor_util import poke, raise_error_or_skip_exception
-from astronomer.providers.utils.typing_compat import Context
-
 
 class BigQueryTableExistenceSensorAsync(BigQueryTableExistenceSensor):
     """
-    Checks for the existence of a table in Google Big Query.
-
-    :param project_id: The Google cloud project in which to look for the table.
-        The connection supplied to the hook must provide
-        access to the specified project.
-    :param dataset_id: The name of the dataset in which to look for the table.
-        storage bucket.
-    :param table_id: The name of the table to check the existence of.
-    :param gcp_conn_id: The connection ID used to connect to Google Cloud.
-    :param bigquery_conn_id: (Deprecated) The connection ID used to connect to Google Cloud.
-        This parameter has been deprecated. You should pass the gcp_conn_id parameter instead.
-    :param delegate_to: (Removed in apache-airflow-providers-google release 10.0.0, use impersonation_chain instead)
-        The account to impersonate using domain-wide delegation of authority, if any. For this to work, the service
-        account making the request must have domain-wide delegation enabled.
-    :param impersonation_chain: Optional service account to impersonate using short-term
-        credentials, or chained list of accounts required to get the access_token
-        of the last account in the list, which will be impersonated in the request.
-        If set as a string, the account must grant the originating account
-        the Service Account Token Creator IAM role.
-        If set as a sequence, the identities from the list must grant
-        Service Account Token Creator IAM role to the directly preceding identity, with first
-        account from the list granting this role to the originating account (templated).
-    :param polling_interval: The interval in seconds to wait between checks table existence.
+    This class is deprecated.
+    Please use :class:`~airflow.providers.google.cloud.sensors.bigquery.BigQueryTableExistenceSensor`
+    and set `deferrable` param to `True` instead.
     """
 
     def __init__(
@@ -47,7 +20,16 @@ def __init__(
         polling_interval: float = 5.0,
         **kwargs: Any,
     ) -> None:
-        super().__init__(**kwargs)
+        warnings.warn(
+            (
+                "This class is deprecated. "
+                "Please use `airflow.providers.google.cloud.sensors.bigquery.BigQueryTableExistenceSensor` "
+                "and set `deferrable` param to `True` instead."
+            ),
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        super().__init__(deferrable=True, **kwargs)
         # TODO: Remove once deprecated
         if polling_interval:
             self.poke_interval = polling_interval
@@ -58,36 +40,3 @@ def __init__(
                 stacklevel=2,
             )
             self.gcp_conn_id = gcp_conn_id
-
-    def execute(self, context: Context) -> None:
-        """Airflow runs this method on the worker and defers using the trigger."""
-        hook_params = {"impersonation_chain": self.impersonation_chain}
-        if hasattr(self, "delegate_to"):  # pragma: no cover
-            hook_params["delegate_to"] = self.delegate_to
-
-        if not poke(self, context):
-            self.defer(
-                timeout=timedelta(seconds=self.timeout),
-                trigger=BigQueryTableExistenceTrigger(
-                    dataset_id=self.dataset_id,
-                    table_id=self.table_id,
-                    project_id=self.project_id,
-                    poke_interval=self.poke_interval,
-                    gcp_conn_id=self.gcp_conn_id,
-                    hook_params=hook_params,
-                ),
-                method_name="execute_complete",
-            )
-
-    def execute_complete(self, context: dict[str, Any], event: dict[str, str] | None = None) -> str:  # type: ignore[return]
-        """
-        Callback for when the trigger fires - returns immediately.
-        Relies on trigger to throw an exception, otherwise it assumes execution was
-        successful.
-        """
-        table_uri = f"{self.project_id}:{self.dataset_id}.{self.table_id}"
-        self.log.info("Sensor checks existence of table: %s", table_uri)
-        if event:
-            if event["status"] == "success":
-                return event["message"]
-            raise_error_or_skip_exception(self.soft_fail, event["message"])
diff --git a/astronomer/providers/google/cloud/triggers/bigquery.py b/astronomer/providers/google/cloud/triggers/bigquery.py
index a8e934299..c70363d07 100644
--- a/astronomer/providers/google/cloud/triggers/bigquery.py
+++ b/astronomer/providers/google/cloud/triggers/bigquery.py
@@ -1,14 +1,8 @@
+from __future__ import annotations
+
 import asyncio
-from typing import (
-    Any,
-    AsyncIterator,
-    Dict,
-    Optional,
-    Sequence,
-    SupportsAbs,
-    Tuple,
-    Union,
-)
+import warnings
+from typing import Any, AsyncIterator, Sequence, SupportsAbs
 
 from aiohttp import ClientSession
 from aiohttp.client_exceptions import ClientResponseError
@@ -38,12 +32,12 @@ class BigQueryInsertJobTrigger(BaseTrigger):
     def __init__(
         self,
         conn_id: str,
-        job_id: Optional[str],
-        project_id: Optional[str],
-        dataset_id: Optional[str] = None,
-        table_id: Optional[str] = None,
-        delegate_to: Optional[str] = None,
-        impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+        job_id: str | None,
+        project_id: str | None,
+        dataset_id: str | None = None,
+        table_id: str | None = None,
+        delegate_to: str | None = None,
+        impersonation_chain: str | Sequence[str] | None = None,
         poll_interval: float = 4.0,
     ):
         super().__init__()
@@ -58,7 +52,7 @@ def __init__(
         self.impersonation_chain = impersonation_chain
         self.poll_interval = poll_interval
 
-    def serialize(self) -> Tuple[str, Dict[str, Any]]:
+    def serialize(self) -> tuple[str, dict[str, Any]]:
         """Serializes BigQueryInsertJobTrigger arguments and classpath."""
         return (
             "astronomer.providers.google.cloud.triggers.bigquery.BigQueryInsertJobTrigger",
@@ -74,7 +68,7 @@ def serialize(self) -> Tuple[str, Dict[str, Any]]:
             },
         )
 
-    async def run(self) -> AsyncIterator["TriggerEvent"]:
+    async def run(self) -> AsyncIterator[TriggerEvent]:
         """Gets current job execution status and yields a TriggerEvent"""
         hook = self._get_async_hook()
         while True:
@@ -113,7 +107,7 @@ def _get_async_hook(self) -> BigQueryHookAsync:
 class BigQueryCheckTrigger(BigQueryInsertJobTrigger):
     """BigQueryCheckTrigger run on the trigger worker"""
 
-    def serialize(self) -> Tuple[str, Dict[str, Any]]:
+    def serialize(self) -> tuple[str, dict[str, Any]]:
         """Serializes BigQueryCheckTrigger arguments and classpath."""
         return (
             "astronomer.providers.google.cloud.triggers.bigquery.BigQueryCheckTrigger",
@@ -128,7 +122,7 @@ def serialize(self) -> Tuple[str, Dict[str, Any]]:
             },
         )
 
-    async def run(self) -> AsyncIterator["TriggerEvent"]:
+    async def run(self) -> AsyncIterator[TriggerEvent]:
         """Gets current job execution status and yields a TriggerEvent"""
         hook = self._get_async_hook()
         while True:
@@ -173,7 +167,7 @@ async def run(self) -> AsyncIterator["TriggerEvent"]:
 class BigQueryGetDataTrigger(BigQueryInsertJobTrigger):
     """BigQueryGetDataTrigger run on the trigger worker, inherits from BigQueryInsertJobTrigger class"""
 
-    def serialize(self) -> Tuple[str, Dict[str, Any]]:
+    def serialize(self) -> tuple[str, dict[str, Any]]:
         """Serializes BigQueryInsertJobTrigger arguments and classpath."""
         return (
             "astronomer.providers.google.cloud.triggers.bigquery.BigQueryGetDataTrigger",
@@ -189,7 +183,7 @@ def serialize(self) -> Tuple[str, Dict[str, Any]]:
             },
         )
 
-    async def run(self) -> AsyncIterator["TriggerEvent"]:
+    async def run(self) -> AsyncIterator[TriggerEvent]:
         """Gets current job execution status and yields a TriggerEvent with response data"""
         hook = self._get_async_hook()
         while True:
@@ -248,16 +242,16 @@ def __init__(
         conn_id: str,
         first_job_id: str,
         second_job_id: str,
-        project_id: Optional[str],
+        project_id: str | None,
         table: str,
-        metrics_thresholds: Dict[str, int],
-        date_filter_column: Optional[str] = "ds",
+        metrics_thresholds: dict[str, int],
+        date_filter_column: str | None = "ds",
         days_back: SupportsAbs[int] = -7,
         ratio_formula: str = "max_over_min",
         ignore_zero: bool = True,
-        dataset_id: Optional[str] = None,
-        table_id: Optional[str] = None,
-        impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+        dataset_id: str | None = None,
+        table_id: str | None = None,
+        impersonation_chain: str | Sequence[str] | None = None,
         poll_interval: float = 4.0,
     ):
         super().__init__(
@@ -280,7 +274,7 @@ def __init__(
         self.ratio_formula = ratio_formula
         self.ignore_zero = ignore_zero
 
-    def serialize(self) -> Tuple[str, Dict[str, Any]]:
+    def serialize(self) -> tuple[str, dict[str, Any]]:
         """Serializes BigQueryCheckTrigger arguments and classpath."""
         return (
             "astronomer.providers.google.cloud.triggers.bigquery.BigQueryIntervalCheckTrigger",
@@ -298,7 +292,7 @@ def serialize(self) -> Tuple[str, Dict[str, Any]]:
             },
         )
 
-    async def run(self) -> AsyncIterator["TriggerEvent"]:
+    async def run(self) -> AsyncIterator[TriggerEvent]:
         """Gets current job execution status and yields a TriggerEvent"""
         hook = self._get_async_hook()
         while True:
@@ -325,14 +319,14 @@ async def run(self) -> AsyncIterator["TriggerEvent"]:
 
                     # If empty list, then no records are available
                     if not first_records:
-                        first_job_row: Optional[str] = None
+                        first_job_row: str | None = None
                     else:
                         # Extract only first record from the query results
                         first_job_row = first_records.pop(0)
 
                     # If empty list, then no records are available
                     if not second_records:
-                        second_job_row: Optional[str] = None
+                        second_job_row: str | None = None
                     else:
                         # Extract only first record from the query results
                         second_job_row = second_records.pop(0)
@@ -391,13 +385,13 @@ def __init__(
         self,
         conn_id: str,
         sql: str,
-        pass_value: Union[int, float, str],
-        job_id: Optional[str],
-        project_id: Optional[str],
+        pass_value: int | float | str,
+        job_id: str | None,
+        project_id: str | None,
         tolerance: Any = None,
-        dataset_id: Optional[str] = None,
-        table_id: Optional[str] = None,
-        impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+        dataset_id: str | None = None,
+        table_id: str | None = None,
+        impersonation_chain: str | Sequence[str] | None = None,
         poll_interval: float = 4.0,
     ):
         super().__init__(
@@ -413,7 +407,7 @@ def __init__(
         self.pass_value = pass_value
         self.tolerance = tolerance
 
-    def serialize(self) -> Tuple[str, Dict[str, Any]]:
+    def serialize(self) -> tuple[str, dict[str, Any]]:
         """Serializes BigQueryValueCheckTrigger arguments and classpath."""
         return (
             "astronomer.providers.google.cloud.triggers.bigquery.BigQueryValueCheckTrigger",
@@ -430,7 +424,7 @@ def serialize(self) -> Tuple[str, Dict[str, Any]]:
             },
         )
 
-    async def run(self) -> AsyncIterator["TriggerEvent"]:
+    async def run(self) -> AsyncIterator[TriggerEvent]:
         """Gets current job execution status and yields a TriggerEvent"""
         hook = self._get_async_hook()
         while True:
@@ -462,6 +456,9 @@ class BigQueryTableExistenceTrigger(BaseTrigger):
     """
     Initialise the BigQuery Table Existence Trigger with needed parameters
 
+    This class is deprecated and will be removed in 2.0.0.
+    Use :class:`~airflow.providers.google.cloud.triggers.bigquery.BigQueryTableExistenceTrigger` instead.
+
     :param project_id: Google Cloud Project where the job is running
     :param dataset_id: The dataset ID of the requested table.
     :param table_id: The table ID of the requested table.
@@ -476,9 +473,18 @@ def __init__(
         dataset_id: str,
         table_id: str,
         gcp_conn_id: str,
-        hook_params: Dict[str, Any],
+        hook_params: dict[str, Any],
         poke_interval: float = 4.0,
     ):
+        warnings.warn(
+            (
+                "This class is deprecated and will be removed in 2.0.0. "
+                "Please use `airflow.providers.google.cloud.triggers.bigquery.BigQueryTableExistenceTrigger`"
+            ),
+            DeprecationWarning,
+            stacklevel=2,
+        )
+
         self.dataset_id = dataset_id
         self.project_id = project_id
         self.table_id = table_id
@@ -486,7 +492,7 @@ def __init__(
         self.poke_interval = poke_interval
         self.hook_params = hook_params
 
-    def serialize(self) -> Tuple[str, Dict[str, Any]]:
+    def serialize(self) -> tuple[str, dict[str, Any]]:
         """Serializes BigQueryTableExistenceTrigger arguments and classpath."""
         return (
             "astronomer.providers.google.cloud.triggers.bigquery.BigQueryTableExistenceTrigger",
@@ -503,7 +509,7 @@ def serialize(self) -> Tuple[str, Dict[str, Any]]:
     def _get_async_hook(self) -> BigQueryTableHookAsync:
         return BigQueryTableHookAsync(gcp_conn_id=self.gcp_conn_id, **self.hook_params)
 
-    async def run(self) -> AsyncIterator["TriggerEvent"]:
+    async def run(self) -> AsyncIterator[TriggerEvent]:
         """Will run until the table exists in the Google Big Query."""
         while True:
             try:
diff --git a/tests/google/cloud/sensors/test_bigquery.py b/tests/google/cloud/sensors/test_bigquery.py
index ffab95d23..48ae0ce23 100644
--- a/tests/google/cloud/sensors/test_bigquery.py
+++ b/tests/google/cloud/sensors/test_bigquery.py
@@ -1,78 +1,19 @@
-from unittest import mock
+from airflow.providers.google.cloud.sensors.bigquery import BigQueryTableExistenceSensor
 
-import pytest
-from airflow.exceptions import AirflowException, TaskDeferred
-
-from astronomer.providers.google.cloud.sensors.bigquery import (
-    BigQueryTableExistenceSensorAsync,
-)
-from astronomer.providers.google.cloud.triggers.bigquery import (
-    BigQueryTableExistenceTrigger,
-)
+from astronomer.providers.google.cloud.sensors.bigquery import BigQueryTableExistenceSensorAsync
 
 PROJECT_ID = "test-astronomer-airflow-providers"
 DATASET_NAME = "test-astro_dataset"
 TABLE_NAME = "test-partitioned_table"
"test-partitioned_table" -MODULE = "astronomer.providers.google.cloud.sensors.bigquery" - class TestBigQueryTableExistenceSensorAsync: - SENSOR = BigQueryTableExistenceSensorAsync( - task_id="bq_check_table", - project_id=PROJECT_ID, - dataset_id=DATASET_NAME, - table_id=TABLE_NAME, - ) - - @mock.patch(f"{MODULE}.BigQueryTableExistenceSensorAsync.defer") - @mock.patch(f"{MODULE}.BigQueryTableExistenceSensorAsync.poke", return_value=True) - def test_big_query_table_existence_sensor_async_finish_before_deferred( - self, mock_poke, mock_defer, context - ): - """Assert task is not deferred when it receives a finish status before deferring""" - self.SENSOR.execute(context) - assert not mock_defer.called - - @mock.patch(f"{MODULE}.BigQueryTableExistenceSensorAsync.poke", return_value=False) - def test_big_query_table_existence_sensor_async(self, context): - """ - Asserts that a task is deferred and a BigQueryTableExistenceTrigger will be fired - when the BigQueryTableExistenceSensorAsync is executed. - """ - - with pytest.raises(TaskDeferred) as exc: - self.SENSOR.execute(context) - assert isinstance( - exc.value.trigger, BigQueryTableExistenceTrigger - ), "Trigger is not a BigQueryTableExistenceTrigger" - - def test_big_query_table_existence_sensor_async_execute_failure(self, context): - """Tests that an AirflowException is raised in case of error event""" - - with pytest.raises(AirflowException): - self.SENSOR.execute_complete( - context=context, event={"status": "error", "message": "test failure message"} - ) - - def test_big_query_table_existence_sensor_async_execute_complete(self): - """Asserts that logging occurs as expected""" - - table_uri = f"{PROJECT_ID}:{DATASET_NAME}.{TABLE_NAME}" - with mock.patch.object(self.SENSOR.log, "info") as mock_log_info: - self.SENSOR.execute_complete( - context=None, event={"status": "success", "message": "Job completed"} - ) - mock_log_info.assert_called_with("Sensor checks existence of table: %s", table_uri) - - def test_poll_interval_deprecation_warning(self): - """Test DeprecationWarning for BigQueryTableExistenceSensorAsync by setting param poll_interval""" - # TODO: Remove once deprecated - with pytest.warns(expected_warning=DeprecationWarning): - BigQueryTableExistenceSensorAsync( - task_id="task-id", - project_id=PROJECT_ID, - dataset_id=DATASET_NAME, - table_id=TABLE_NAME, - polling_interval=5.0, - ) + def test_init(self): + task = BigQueryTableExistenceSensorAsync( + task_id="bq_check_table", + project_id=PROJECT_ID, + dataset_id=DATASET_NAME, + table_id=TABLE_NAME, + ) + assert isinstance(task, BigQueryTableExistenceSensor) + assert task.deferrable is True