Commit
feat(providers/google): deprecate BigQueryValueCheckOperatorAsync
Lee-W committed Jan 24, 2024
1 parent c67354e commit 13dfa25
Showing 3 changed files with 35 additions and 157 deletions.
86 changes: 17 additions & 69 deletions astronomer/providers/google/cloud/operators/bigquery.py
@@ -4,8 +4,6 @@
import warnings
from typing import Any

from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook, BigQueryJob
from airflow.providers.google.cloud.operators.bigquery import (
BigQueryCheckOperator,
BigQueryGetDataOperator,
@@ -14,11 +12,6 @@
BigQueryValueCheckOperator,
)

from astronomer.providers.google.cloud.triggers.bigquery import (
BigQueryValueCheckTrigger,
)
from astronomer.providers.utils.typing_compat import Context

BIGQUERY_JOB_DETAILS_LINK_FMT = "https://console.cloud.google.com/bigquery?j={job_id}"


@@ -111,68 +104,23 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
self.poll_interval = poll_interval


class BigQueryValueCheckOperatorAsync(BigQueryValueCheckOperator): # noqa: D101
class BigQueryValueCheckOperatorAsync(BigQueryValueCheckOperator):
"""
This class is deprecated.
Please use :class: `~airflow.providers.google.cloud.operators.bigquery.BigQueryValueCheckOperator`
and set `deferrable` param to `True` instead.
"""

def __init__(self, *args: Any, **kwargs: Any) -> None:
warnings.warn(
(
"This class is deprecated."
"Please use `airflow.providers.google.cloud.operators.bigquery.BigQueryValueCheckOperator`"
"and set `deferrable` param to `True` instead."
),
DeprecationWarning,
stacklevel=2,
)
poll_interval: float = kwargs.pop("poll_interval", 4.0)
super().__init__(*args, **kwargs)
super().__init__(*args, deferrable=True, **kwargs)
self.poll_interval = poll_interval

def _submit_job(
self,
hook: BigQueryHook,
job_id: str,
) -> BigQueryJob:
"""Submit a new job and get the job id for polling the status using Triggerer."""
configuration = {
"query": {
"query": self.sql,
"useLegacySql": False,
}
}
if self.use_legacy_sql:
configuration["query"]["useLegacySql"] = self.use_legacy_sql

return hook.insert_job(
configuration=configuration,
project_id=hook.project_id,
location=self.location,
job_id=job_id,
nowait=True,
)

def execute(self, context: Context) -> None: # noqa: D102
hook = BigQueryHook(gcp_conn_id=self.gcp_conn_id)

job = self._submit_job(hook, job_id="")
context["ti"].xcom_push(key="job_id", value=job.job_id)
if job.running():
self.defer(
timeout=self.execution_timeout,
trigger=BigQueryValueCheckTrigger(
conn_id=self.gcp_conn_id,
job_id=job.job_id,
project_id=hook.project_id,
sql=self.sql,
pass_value=self.pass_value,
tolerance=self.tol,
impersonation_chain=self.impersonation_chain,
poll_interval=self.poll_interval,
),
method_name="execute_complete",
)
else:
super().execute(context=context)

def execute_complete(self, context: Context, event: dict[str, Any]) -> None:
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
successful.
"""
if event["status"] == "error":
raise AirflowException(event["message"])
self.log.info(
"%s completed with response %s ",
self.task_id,
event["message"],
)
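
The shim above now just forwards to the upstream operator with `deferrable=True`, as the new docstring advises. For reference, a minimal sketch of the recommended replacement in a DAG — the DAG id, table, and connection id are hypothetical, and `deferrable` requires a provider version that supports deferrable BigQuery operators:

```python
# Sketch of the recommended replacement for BigQueryValueCheckOperatorAsync:
# use the upstream operator and let it defer itself via deferrable=True.
from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.operators.bigquery import BigQueryValueCheckOperator

with DAG(
    dag_id="bq_value_check_example",  # hypothetical DAG id
    start_date=datetime(2024, 1, 1),
    schedule=None,
) as dag:
    check_value = BigQueryValueCheckOperator(
        task_id="check_value",
        sql="SELECT COUNT(*) FROM my_dataset.my_table",  # hypothetical table
        pass_value=2,
        use_legacy_sql=False,
        gcp_conn_id="google_cloud_default",  # hypothetical connection id
        deferrable=True,  # replaces the Async subclass
    )
```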
12 changes: 12 additions & 0 deletions astronomer/providers/google/cloud/triggers/bigquery.py
@@ -429,6 +429,9 @@ class BigQueryValueCheckTrigger(BigQueryInsertJobTrigger):
"""
BigQueryValueCheckTrigger run on the trigger worker, inherits from BigQueryInsertJobTrigger class
This class is deprecated and will be removed in 2.0.0.
Use :class: `~airflow.providers.google.cloud.triggers.bigquery.BigQueryValueCheckTrigger` instead
:param conn_id: Reference to google cloud connection id
:param sql: the sql to be executed
:param pass_value: pass value
@@ -455,6 +458,15 @@ def __init__(
impersonation_chain: str | Sequence[str] | None = None,
poll_interval: float = 4.0,
):
warnings.warn(
(
"This module is deprecated and will be removed in 2.0.0."
"Please use `airflow.providers.google.cloud.triggers.bigquery.BigQueryValueCheckTrigger`"
),
DeprecationWarning,
stacklevel=2,
)

super().__init__(
conn_id=conn_id,
job_id=job_id,
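Because the trigger and operator now emit a `DeprecationWarning` at construction time, it can help to surface those warnings loudly while migrating DAGs. A small sketch using only the standard library (no project-specific assumptions):

```python
import warnings

# Escalate DeprecationWarning to an error so any remaining use of the
# deprecated astronomer classes fails loudly (e.g. in CI) instead of
# being silently ignored.
warnings.filterwarnings("error", category=DeprecationWarning)
```

The same effect can be had from the command line with `python -W error::DeprecationWarning`, or in pytest with `-W error::DeprecationWarning`.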
94 changes: 6 additions & 88 deletions tests/google/cloud/operators/test_bigquery.py
@@ -1,13 +1,9 @@
from unittest import mock
from unittest.mock import MagicMock

import pytest
from airflow.exceptions import AirflowException, TaskDeferred
from airflow.providers.google.cloud.operators.bigquery import (
BigQueryCheckOperator,
BigQueryGetDataOperator,
BigQueryInsertJobOperator,
BigQueryIntervalCheckOperator,
BigQueryValueCheckOperator,
)

from astronomer.providers.google.cloud.operators.bigquery import (
@@ -17,10 +13,6 @@
BigQueryIntervalCheckOperatorAsync,
BigQueryValueCheckOperatorAsync,
)
from astronomer.providers.google.cloud.triggers.bigquery import (
BigQueryValueCheckTrigger,
)
from tests.utils.airflow_util import create_context

TEST_DATASET_LOCATION = "EU"
TEST_GCP_PROJECT_ID = "test-project"
@@ -86,88 +78,14 @@ def test_init(self):


class TestBigQueryValueCheckOperatorAsync:
def _get_value_check_async_operator(self, use_legacy_sql: bool = False):
"""Helper function to initialise BigQueryValueCheckOperatorAsync operator"""
def test_init(self):
query = "SELECT COUNT(*) FROM Any"
pass_val = 2

return BigQueryValueCheckOperatorAsync(
task = BigQueryValueCheckOperatorAsync(
task_id="check_value",
sql=query,
pass_value=pass_val,
use_legacy_sql=use_legacy_sql,
use_legacy_sql=True,
)

@mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryValueCheckOperator.execute")
@mock.patch("astronomer.providers.google.cloud.operators.bigquery.BigQueryValueCheckOperatorAsync.defer")
@mock.patch("astronomer.providers.google.cloud.operators.bigquery.BigQueryHook")
def test_bigquery_value_check_async_finish_before_deferred(self, mock_hook, mock_defer, mock_execute):
operator = self._get_value_check_async_operator(True)
job_id = "123456"
hash_ = "hash"
real_job_id = f"{job_id}_{hash_}"
mock_hook.return_value.insert_job.return_value = MagicMock(job_id=real_job_id, error_result=False)
mock_hook.return_value.insert_job.return_value.running.return_value = False

operator.execute(create_context(operator))
assert not mock_defer.called
assert mock_execute.called

@mock.patch("astronomer.providers.google.cloud.operators.bigquery.BigQueryHook")
def test_bigquery_value_check_async(self, mock_hook):
"""
Asserts that a task is deferred and a BigQueryValueCheckTrigger will be fired
when the BigQueryValueCheckOperatorAsync is executed.
"""
operator = self._get_value_check_async_operator(True)
job_id = "123456"
hash_ = "hash"
real_job_id = f"{job_id}_{hash_}"
mock_hook.return_value.insert_job.return_value = MagicMock(job_id=real_job_id, error_result=False)
with pytest.raises(TaskDeferred) as exc:
operator.execute(create_context(operator))

assert isinstance(
exc.value.trigger, BigQueryValueCheckTrigger
), "Trigger is not a BigQueryValueCheckTrigger"

def test_bigquery_value_check_operator_execute_complete_success(self):
"""Tests response message in case of success event"""
operator = self._get_value_check_async_operator()

assert (
operator.execute_complete(context=None, event={"status": "success", "message": "Job completed!"})
is None
)

def test_bigquery_value_check_operator_execute_complete_failure(self):
"""Tests that an AirflowException is raised in case of error event"""
operator = self._get_value_check_async_operator()

with pytest.raises(AirflowException):
operator.execute_complete(
context=None, event={"status": "error", "message": "test failure message"}
)

@pytest.mark.parametrize(
"kwargs, expected",
[
({"sql": "SELECT COUNT(*) from Any"}, "missing keyword argument 'pass_value'"),
({"pass_value": "Any"}, "missing keyword argument 'sql'"),
],
)
def test_bigquery_value_check_missing_param(self, kwargs, expected):
"""Assert the exception if require param not pass to BigQueryValueCheckOperatorAsync operator"""
with pytest.raises(AirflowException) as missing_param:
BigQueryValueCheckOperatorAsync(**kwargs)
assert missing_param.value.args[0] == expected

def test_bigquery_value_check_empty(self):
"""Assert the exception if require param not pass to BigQueryValueCheckOperatorAsync operator"""
expected, expected1 = (
"missing keyword arguments 'sql', 'pass_value'",
"missing keyword arguments 'pass_value', 'sql'",
)
with pytest.raises(AirflowException) as missing_param:
BigQueryValueCheckOperatorAsync(kwargs={})
assert (missing_param.value.args[0] == expected) or (missing_param.value.args[0] == expected1)
assert isinstance(task, BigQueryValueCheckOperator)
assert task.deferrable is True

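If the deprecation warning itself should be covered by tests, a possible companion check (a sketch, not part of this commit; names mirror the `test_init` above) could look like:

```python
# Sketch: assert that constructing the deprecated operator emits the warning
# added in this commit, and that it still resolves to the upstream operator.
import pytest
from airflow.providers.google.cloud.operators.bigquery import BigQueryValueCheckOperator

from astronomer.providers.google.cloud.operators.bigquery import BigQueryValueCheckOperatorAsync


def test_deprecation_warning():
    with pytest.warns(DeprecationWarning, match="deferrable"):
        task = BigQueryValueCheckOperatorAsync(
            task_id="check_value",
            sql="SELECT COUNT(*) FROM Any",
            pass_value=2,
        )
    assert isinstance(task, BigQueryValueCheckOperator)
    assert task.deferrable is True
```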