Skip to content

Commit

Permalink
feat(providers/google): deprecate BigQueryIntervalCheckOperatorAsync
Browse files Browse the repository at this point in the history
  • Loading branch information
Lee-W committed Jan 24, 2024
1 parent 7c6eb9d commit c67354e
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 170 deletions.
107 changes: 13 additions & 94 deletions astronomer/providers/google/cloud/operators/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
)

from astronomer.providers.google.cloud.triggers.bigquery import (
BigQueryIntervalCheckTrigger,
BigQueryValueCheckTrigger,
)
from astronomer.providers.utils.typing_compat import Context
Expand Down Expand Up @@ -92,105 +91,25 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:

class BigQueryIntervalCheckOperatorAsync(BigQueryIntervalCheckOperator):
"""
Checks asynchronously that the values of metrics given as SQL expressions are within
a certain tolerance of the ones from days_back before.
This method constructs a query like so ::
SELECT {metrics_threshold_dict_key} FROM {table}
WHERE {date_filter_column}=<date>
:param table: the table name
:param days_back: number of days between ds and the ds we want to check
against. Defaults to 7 days
:param metrics_thresholds: a dictionary of ratios indexed by metrics, for
example 'COUNT(*)': 1.5 would require a 50 percent or less difference
between the current day, and the prior days_back.
:param use_legacy_sql: Whether to use legacy SQL (true)
or standard SQL (false).
:param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud.
:param location: The geographic location of the job. See details at:
https://cloud.google.com/bigquery/docs/locations#specifying_your_location
:param impersonation_chain: Optional service account to impersonate using short-term
credentials, or chained list of accounts required to get the access_token
of the last account in the list, which will be impersonated in the request.
If set as a string, the account must grant the originating account
the Service Account Token Creator IAM role.
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
:param labels: a dictionary containing labels for the table, passed to BigQuery
:param poll_interval: polling period in seconds to check for the status of job. Defaults to 4 seconds.
This class is deprecated.
Please use :class:`~airflow.providers.google.cloud.operators.bigquery.BigQueryIntervalCheckOperator`
and set `deferrable` param to `True` instead.
"""

def __init__(self, *args: Any, **kwargs: Any) -> None:
warnings.warn(
(
"This class is deprecated."
"Please use `airflow.providers.google.cloud.operators.bigquery.BigQueryIntervalCheckOperator`"
"and set `deferrable` param to `True` instead."
),
DeprecationWarning,
stacklevel=2,
)
poll_interval: float = kwargs.pop("poll_interval", 4.0)
super().__init__(*args, **kwargs)
super().__init__(*args, deferrable=True, **kwargs)
self.poll_interval = poll_interval

def _submit_job(
self,
hook: BigQueryHook,
sql: str,
job_id: str,
) -> BigQueryJob:
"""Submit a new job and get the job id for polling the status using Triggerer."""
configuration = {"query": {"query": sql}}
return hook.insert_job(
configuration=configuration,
project_id=hook.project_id,
location=self.location,
job_id=job_id,
nowait=True,
)

def execute(self, context: Context) -> None:
"""Execute the job in sync mode and defers the trigger with job id to poll for the status"""
hook = BigQueryHook(gcp_conn_id=self.gcp_conn_id)
self.log.info("Using ratio formula: %s", self.ratio_formula)

self.log.info("Executing SQL check: %s", self.sql1)
job_1 = self._submit_job(hook, sql=self.sql1, job_id="")
context["ti"].xcom_push(key="job_id", value=job_1.job_id)

self.log.info("Executing SQL check: %s", self.sql2)
job_2 = self._submit_job(hook, sql=self.sql2, job_id="")
if job_1.running() or job_2.running():
self.defer(
timeout=self.execution_timeout,
trigger=BigQueryIntervalCheckTrigger(
conn_id=self.gcp_conn_id,
first_job_id=job_1.job_id,
second_job_id=job_2.job_id,
project_id=hook.project_id,
table=self.table,
metrics_thresholds=self.metrics_thresholds,
date_filter_column=self.date_filter_column,
days_back=self.days_back,
ratio_formula=self.ratio_formula,
ignore_zero=self.ignore_zero,
impersonation_chain=self.impersonation_chain,
poll_interval=self.poll_interval,
),
method_name="execute_complete",
)
else:
super().execute(context=context)

def execute_complete(self, context: Context, event: dict[str, Any]) -> None:
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
successful.
"""
if event["status"] == "error":
raise AirflowException(event["message"])

self.log.info(
"%s completed with response %s ",
self.task_id,
event["status"],
)


class BigQueryValueCheckOperatorAsync(BigQueryValueCheckOperator): # noqa: D101
def __init__(self, *args: Any, **kwargs: Any) -> None:
Expand Down
12 changes: 12 additions & 0 deletions astronomer/providers/google/cloud/triggers/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,9 @@ class BigQueryIntervalCheckTrigger(BigQueryInsertJobTrigger):
"""
BigQueryIntervalCheckTrigger run on the trigger worker, inherits from BigQueryInsertJobTrigger class
This class is deprecated and will be removed in 2.0.0.
Use :class:`~airflow.providers.google.cloud.triggers.bigquery.BigQueryIntervalCheckTrigger` instead
:param conn_id: Reference to google cloud connection id
:param first_job_id: The ID of the job 1 performed
:param second_job_id: The ID of the job 2 performed
Expand Down Expand Up @@ -303,6 +306,15 @@ def __init__(
impersonation_chain: str | Sequence[str] | None = None,
poll_interval: float = 4.0,
):
warnings.warn(
(
"This module is deprecated and will be removed in 2.0.0."
"Please use `airflow.providers.google.cloud.triggers.bigquery.BigQueryIntervalCheckTrigger`"
),
DeprecationWarning,
stacklevel=2,
)

super().__init__(
conn_id=conn_id,
job_id=first_job_id,
Expand Down
81 changes: 5 additions & 76 deletions tests/google/cloud/operators/test_bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
BigQueryCheckOperator,
BigQueryGetDataOperator,
BigQueryInsertJobOperator,
BigQueryIntervalCheckOperator,
)

from astronomer.providers.google.cloud.operators.bigquery import (
Expand All @@ -17,7 +18,6 @@
BigQueryValueCheckOperatorAsync,
)
from astronomer.providers.google.cloud.triggers.bigquery import (
BigQueryIntervalCheckTrigger,
BigQueryValueCheckTrigger,
)
from tests.utils.airflow_util import create_context
Expand Down Expand Up @@ -61,86 +61,15 @@ def test_init(self):


class TestBigQueryIntervalCheckOperatorAsync:
def test_bigquery_interval_check_operator_execute_complete(self):
"""Asserts that logging occurs as expected"""

operator = BigQueryIntervalCheckOperatorAsync(
task_id="bq_interval_check_operator_execute_complete",
table="test_table",
metrics_thresholds={"COUNT(*)": 1.5},
location=TEST_DATASET_LOCATION,
)

with mock.patch.object(operator.log, "info") as mock_log_info:
operator.execute_complete(context=None, event={"status": "success", "message": "Job completed"})
mock_log_info.assert_called_with(
"%s completed with response %s ", "bq_interval_check_operator_execute_complete", "success"
)

def test_bigquery_interval_check_operator_execute_failure(self, context):
"""Tests that an AirflowException is raised in case of error event"""

operator = BigQueryIntervalCheckOperatorAsync(
task_id="bq_interval_check_operator_execute_complete",
table="test_table",
metrics_thresholds={"COUNT(*)": 1.5},
location=TEST_DATASET_LOCATION,
)

with pytest.raises(AirflowException):
operator.execute_complete(
context=None, event={"status": "error", "message": "test failure message"}
)

@mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryIntervalCheckOperator.execute")
@mock.patch("astronomer.providers.google.cloud.operators.bigquery.BigQueryIntervalCheckOperator.defer")
@mock.patch("astronomer.providers.google.cloud.operators.bigquery.BigQueryHook")
def test_bigquery_interval_check_operator_async_finish_before_defer(
self, mock_hook, mock_defer, mock_execute
):
job_id = "123456"
hash_ = "hash"
real_job_id = f"{job_id}_{hash_}"

mock_hook.return_value.insert_job.return_value = MagicMock(job_id=real_job_id, error_result=False)
mock_hook.return_value.insert_job.return_value.running.return_value = False

op = BigQueryIntervalCheckOperatorAsync(
task_id="bq_interval_check_operator_execute_complete",
table="test_table",
metrics_thresholds={"COUNT(*)": 1.5},
location=TEST_DATASET_LOCATION,
)

op.execute(create_context(op))
assert not mock_defer.called
assert mock_execute.called

@mock.patch("astronomer.providers.google.cloud.operators.bigquery.BigQueryHook")
def test_bigquery_interval_check_operator_async(self, mock_hook):
"""
Asserts that a task is deferred and a BigQueryIntervalCheckTrigger will be fired
when the BigQueryIntervalCheckOperatorAsync is executed.
"""
job_id = "123456"
hash_ = "hash"
real_job_id = f"{job_id}_{hash_}"

mock_hook.return_value.insert_job.return_value = MagicMock(job_id=real_job_id, error_result=False)

op = BigQueryIntervalCheckOperatorAsync(
def test_init(self):
task = BigQueryIntervalCheckOperatorAsync(
task_id="bq_interval_check_operator_execute_complete",
table="test_table",
metrics_thresholds={"COUNT(*)": 1.5},
location=TEST_DATASET_LOCATION,
)

with pytest.raises(TaskDeferred) as exc:
op.execute(create_context(op))

assert isinstance(
exc.value.trigger, BigQueryIntervalCheckTrigger
), "Trigger is not a BigQueryIntervalCheckTrigger"
assert isinstance(task, BigQueryIntervalCheckOperator)
assert task.deferrable is True


class TestBigQueryGetDataOperatorAsync:
Expand Down

0 comments on commit c67354e

Please sign in to comment.