diff --git a/astronomer/providers/google/cloud/operators/bigquery.py b/astronomer/providers/google/cloud/operators/bigquery.py index 87292236c..69bdf34e5 100644 --- a/astronomer/providers/google/cloud/operators/bigquery.py +++ b/astronomer/providers/google/cloud/operators/bigquery.py @@ -4,8 +4,6 @@ import warnings from typing import Any -from airflow.exceptions import AirflowException -from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook, BigQueryJob from airflow.providers.google.cloud.operators.bigquery import ( BigQueryCheckOperator, BigQueryGetDataOperator, @@ -14,11 +12,6 @@ BigQueryValueCheckOperator, ) -from astronomer.providers.google.cloud.triggers.bigquery import ( - BigQueryValueCheckTrigger, -) -from astronomer.providers.utils.typing_compat import Context - BIGQUERY_JOB_DETAILS_LINK_FMT = "https://console.cloud.google.com/bigquery?j={job_id}" @@ -111,68 +104,23 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self.poll_interval = poll_interval -class BigQueryValueCheckOperatorAsync(BigQueryValueCheckOperator): # noqa: D101 +class BigQueryValueCheckOperatorAsync(BigQueryValueCheckOperator): + """ + This class is deprecated. + Please use :class: `~airflow.providers.google.cloud.operators.bigquery.BigQueryValueCheckOperator` + and set `deferrable` param to `True` instead. + """ + def __init__(self, *args: Any, **kwargs: Any) -> None: + warnings.warn( + ( + "This class is deprecated." + "Please use `airflow.providers.google.cloud.operators.bigquery.BigQueryValueCheckOperator`" + "and set `deferrable` param to `True` instead." + ), + DeprecationWarning, + stacklevel=2, + ) poll_interval: float = kwargs.pop("poll_interval", 4.0) - super().__init__(*args, **kwargs) + super().__init__(*args, deferrable=True, **kwargs) self.poll_interval = poll_interval - - def _submit_job( - self, - hook: BigQueryHook, - job_id: str, - ) -> BigQueryJob: - """Submit a new job and get the job id for polling the status using Triggerer.""" - configuration = { - "query": { - "query": self.sql, - "useLegacySql": False, - } - } - if self.use_legacy_sql: - configuration["query"]["useLegacySql"] = self.use_legacy_sql - - return hook.insert_job( - configuration=configuration, - project_id=hook.project_id, - location=self.location, - job_id=job_id, - nowait=True, - ) - - def execute(self, context: Context) -> None: # noqa: D102 - hook = BigQueryHook(gcp_conn_id=self.gcp_conn_id) - - job = self._submit_job(hook, job_id="") - context["ti"].xcom_push(key="job_id", value=job.job_id) - if job.running(): - self.defer( - timeout=self.execution_timeout, - trigger=BigQueryValueCheckTrigger( - conn_id=self.gcp_conn_id, - job_id=job.job_id, - project_id=hook.project_id, - sql=self.sql, - pass_value=self.pass_value, - tolerance=self.tol, - impersonation_chain=self.impersonation_chain, - poll_interval=self.poll_interval, - ), - method_name="execute_complete", - ) - else: - super().execute(context=context) - - def execute_complete(self, context: Context, event: dict[str, Any]) -> None: - """ - Callback for when the trigger fires - returns immediately. - Relies on trigger to throw an exception, otherwise it assumes execution was - successful. - """ - if event["status"] == "error": - raise AirflowException(event["message"]) - self.log.info( - "%s completed with response %s ", - self.task_id, - event["message"], - ) diff --git a/astronomer/providers/google/cloud/triggers/bigquery.py b/astronomer/providers/google/cloud/triggers/bigquery.py index 268c8e1e9..e9b3ae399 100644 --- a/astronomer/providers/google/cloud/triggers/bigquery.py +++ b/astronomer/providers/google/cloud/triggers/bigquery.py @@ -429,6 +429,9 @@ class BigQueryValueCheckTrigger(BigQueryInsertJobTrigger): """ BigQueryValueCheckTrigger run on the trigger worker, inherits from BigQueryInsertJobTrigger class + This class is deprecated and will be removed in 2.0.0. + Use :class: `~airflow.providers.google.cloud.triggers.bigquery.BigQueryValueCheckTrigger` instead + :param conn_id: Reference to google cloud connection id :param sql: the sql to be executed :param pass_value: pass value @@ -455,6 +458,15 @@ def __init__( impersonation_chain: str | Sequence[str] | None = None, poll_interval: float = 4.0, ): + warnings.warn( + ( + "This module is deprecated and will be removed in 2.0.0." + "Please use `airflow.providers.google.cloud.triggers.bigquery.BigQueryValueCheckTrigger`" + ), + DeprecationWarning, + stacklevel=2, + ) + super().__init__( conn_id=conn_id, job_id=job_id, diff --git a/tests/google/cloud/operators/test_bigquery.py b/tests/google/cloud/operators/test_bigquery.py index 091ca5e10..40432022c 100644 --- a/tests/google/cloud/operators/test_bigquery.py +++ b/tests/google/cloud/operators/test_bigquery.py @@ -1,13 +1,9 @@ -from unittest import mock -from unittest.mock import MagicMock - -import pytest -from airflow.exceptions import AirflowException, TaskDeferred from airflow.providers.google.cloud.operators.bigquery import ( BigQueryCheckOperator, BigQueryGetDataOperator, BigQueryInsertJobOperator, BigQueryIntervalCheckOperator, + BigQueryValueCheckOperator, ) from astronomer.providers.google.cloud.operators.bigquery import ( @@ -17,10 +13,6 @@ BigQueryIntervalCheckOperatorAsync, BigQueryValueCheckOperatorAsync, ) -from astronomer.providers.google.cloud.triggers.bigquery import ( - BigQueryValueCheckTrigger, -) -from tests.utils.airflow_util import create_context TEST_DATASET_LOCATION = "EU" TEST_GCP_PROJECT_ID = "test-project" @@ -86,88 +78,14 @@ def test_init(self): class TestBigQueryValueCheckOperatorAsync: - def _get_value_check_async_operator(self, use_legacy_sql: bool = False): - """Helper function to initialise BigQueryValueCheckOperatorAsync operator""" + def test_init(self): query = "SELECT COUNT(*) FROM Any" pass_val = 2 - - return BigQueryValueCheckOperatorAsync( + task = BigQueryValueCheckOperatorAsync( task_id="check_value", sql=query, pass_value=pass_val, - use_legacy_sql=use_legacy_sql, + use_legacy_sql=True, ) - - @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryValueCheckOperator.execute") - @mock.patch("astronomer.providers.google.cloud.operators.bigquery.BigQueryValueCheckOperatorAsync.defer") - @mock.patch("astronomer.providers.google.cloud.operators.bigquery.BigQueryHook") - def test_bigquery_value_check_async_finish_before_deferred(self, mock_hook, mock_defer, mock_execute): - operator = self._get_value_check_async_operator(True) - job_id = "123456" - hash_ = "hash" - real_job_id = f"{job_id}_{hash_}" - mock_hook.return_value.insert_job.return_value = MagicMock(job_id=real_job_id, error_result=False) - mock_hook.return_value.insert_job.return_value.running.return_value = False - - operator.execute(create_context(operator)) - assert not mock_defer.called - assert mock_execute.called - - @mock.patch("astronomer.providers.google.cloud.operators.bigquery.BigQueryHook") - def test_bigquery_value_check_async(self, mock_hook): - """ - Asserts that a task is deferred and a BigQueryValueCheckTrigger will be fired - when the BigQueryValueCheckOperatorAsync is executed. - """ - operator = self._get_value_check_async_operator(True) - job_id = "123456" - hash_ = "hash" - real_job_id = f"{job_id}_{hash_}" - mock_hook.return_value.insert_job.return_value = MagicMock(job_id=real_job_id, error_result=False) - with pytest.raises(TaskDeferred) as exc: - operator.execute(create_context(operator)) - - assert isinstance( - exc.value.trigger, BigQueryValueCheckTrigger - ), "Trigger is not a BigQueryValueCheckTrigger" - - def test_bigquery_value_check_operator_execute_complete_success(self): - """Tests response message in case of success event""" - operator = self._get_value_check_async_operator() - - assert ( - operator.execute_complete(context=None, event={"status": "success", "message": "Job completed!"}) - is None - ) - - def test_bigquery_value_check_operator_execute_complete_failure(self): - """Tests that an AirflowException is raised in case of error event""" - operator = self._get_value_check_async_operator() - - with pytest.raises(AirflowException): - operator.execute_complete( - context=None, event={"status": "error", "message": "test failure message"} - ) - - @pytest.mark.parametrize( - "kwargs, expected", - [ - ({"sql": "SELECT COUNT(*) from Any"}, "missing keyword argument 'pass_value'"), - ({"pass_value": "Any"}, "missing keyword argument 'sql'"), - ], - ) - def test_bigquery_value_check_missing_param(self, kwargs, expected): - """Assert the exception if require param not pass to BigQueryValueCheckOperatorAsync operator""" - with pytest.raises(AirflowException) as missing_param: - BigQueryValueCheckOperatorAsync(**kwargs) - assert missing_param.value.args[0] == expected - - def test_bigquery_value_check_empty(self): - """Assert the exception if require param not pass to BigQueryValueCheckOperatorAsync operator""" - expected, expected1 = ( - "missing keyword arguments 'sql', 'pass_value'", - "missing keyword arguments 'pass_value', 'sql'", - ) - with pytest.raises(AirflowException) as missing_param: - BigQueryValueCheckOperatorAsync(kwargs={}) - assert (missing_param.value.args[0] == expected) or (missing_param.value.args[0] == expected1) + assert isinstance(task, BigQueryValueCheckOperator) + assert task.deferrable is True