Skip to content

Commit

Permalink
Add encryption_configuration parameter to BigQuery operators
Browse files Browse the repository at this point in the history
  • Loading branch information
VladaZakharova committed Jun 5, 2024
1 parent 5aa43e2 commit 7453f3a
Show file tree
Hide file tree
Showing 2 changed files with 222 additions and 6 deletions.
64 changes: 58 additions & 6 deletions airflow/providers/google/cloud/operators/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,12 @@ class _BigQueryOperatorsEncryptionConfigurationMixin:
# annotation of the `self`. Then you can inherit this class in the target operator.
# e.g: BigQueryCheckOperator, BigQueryTableCheckOperator
def include_encryption_configuration( # type:ignore[misc]
self: BigQueryCheckOperator | BigQueryTableCheckOperator,
self: BigQueryCheckOperator
| BigQueryTableCheckOperator
| BigQueryValueCheckOperator
| BigQueryColumnCheckOperator
| BigQueryGetDataOperator
| BigQueryIntervalCheckOperator,
configuration: dict,
config_key: str,
) -> None:
Expand Down Expand Up @@ -326,7 +331,9 @@ def execute_complete(self, context: Context, event: dict[str, Any]) -> None:
self.log.info("Success.")


class BigQueryValueCheckOperator(_BigQueryDbHookMixin, SQLValueCheckOperator):
class BigQueryValueCheckOperator(
_BigQueryDbHookMixin, SQLValueCheckOperator, _BigQueryOperatorsEncryptionConfigurationMixin
):
"""Perform a simple value check using sql code.
.. seealso::
Expand All @@ -336,6 +343,13 @@ class BigQueryValueCheckOperator(_BigQueryDbHookMixin, SQLValueCheckOperator):
:param sql: SQL to execute.
:param use_legacy_sql: Whether to use legacy SQL (true)
or standard SQL (false).
:param encryption_configuration: [Optional] Custom encryption configuration (e.g., Cloud KMS keys).

    .. code-block:: python

        encryption_configuration = {
            "kmsKeyName": "projects/PROJECT/locations/LOCATION/keyRings/KEY_RING/cryptoKeys/KEY",
        }

:param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud.
:param location: The geographic location of the job. See details at:
https://cloud.google.com/bigquery/docs/locations#specifying_your_location
Expand Down Expand Up @@ -370,6 +384,7 @@ def __init__(
sql: str,
pass_value: Any,
tolerance: Any = None,
encryption_configuration: dict | None = None,
gcp_conn_id: str = "google_cloud_default",
use_legacy_sql: bool = True,
location: str | None = None,
Expand All @@ -383,6 +398,7 @@ def __init__(
self.location = location
self.gcp_conn_id = gcp_conn_id
self.use_legacy_sql = use_legacy_sql
self.encryption_configuration = encryption_configuration
self.impersonation_chain = impersonation_chain
self.labels = labels
self.deferrable = deferrable
Expand All @@ -401,6 +417,8 @@ def _submit_job(
},
}

self.include_encryption_configuration(configuration, "query")

return hook.insert_job(
configuration=configuration,
project_id=hook.project_id,
Expand Down Expand Up @@ -460,7 +478,9 @@ def execute_complete(self, context: Context, event: dict[str, Any]) -> None:
)


class BigQueryIntervalCheckOperator(_BigQueryDbHookMixin, SQLIntervalCheckOperator):
class BigQueryIntervalCheckOperator(
_BigQueryDbHookMixin, SQLIntervalCheckOperator, _BigQueryOperatorsEncryptionConfigurationMixin
):
"""
Check that the values of metrics given as SQL expressions are within a tolerance of the older ones.
Expand All @@ -481,6 +501,13 @@ class BigQueryIntervalCheckOperator(_BigQueryDbHookMixin, SQLIntervalCheckOperat
between the current day, and the prior days_back.
:param use_legacy_sql: Whether to use legacy SQL (true)
or standard SQL (false).
:param encryption_configuration: [Optional] Custom encryption configuration (e.g., Cloud KMS keys).

    .. code-block:: python

        encryption_configuration = {
            "kmsKeyName": "projects/PROJECT/locations/LOCATION/keyRings/KEY_RING/cryptoKeys/KEY",
        }

:param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud.
:param location: The geographic location of the job. See details at:
https://cloud.google.com/bigquery/docs/locations#specifying_your_location
Expand Down Expand Up @@ -520,6 +547,7 @@ def __init__(
gcp_conn_id: str = "google_cloud_default",
use_legacy_sql: bool = True,
location: str | None = None,
encryption_configuration: dict | None = None,
impersonation_chain: str | Sequence[str] | None = None,
labels: dict | None = None,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
Expand All @@ -538,6 +566,7 @@ def __init__(
self.gcp_conn_id = gcp_conn_id
self.use_legacy_sql = use_legacy_sql
self.location = location
self.encryption_configuration = encryption_configuration
self.impersonation_chain = impersonation_chain
self.labels = labels
self.project_id = project_id
Expand All @@ -552,6 +581,7 @@ def _submit_job(
) -> BigQueryJob:
"""Submit a new job and get the job id for polling the status using Triggerer."""
configuration = {"query": {"query": sql, "useLegacySql": self.use_legacy_sql}}
self.include_encryption_configuration(configuration, "query")
return hook.insert_job(
configuration=configuration,
project_id=self.project_id or hook.project_id,
Expand Down Expand Up @@ -608,7 +638,9 @@ def execute_complete(self, context: Context, event: dict[str, Any]) -> None:
)


class BigQueryColumnCheckOperator(_BigQueryDbHookMixin, SQLColumnCheckOperator):
class BigQueryColumnCheckOperator(
_BigQueryDbHookMixin, SQLColumnCheckOperator, _BigQueryOperatorsEncryptionConfigurationMixin
):
"""
Subclasses the SQLColumnCheckOperator in order to provide a job id for OpenLineage to parse.
Expand All @@ -623,6 +655,13 @@ class BigQueryColumnCheckOperator(_BigQueryDbHookMixin, SQLColumnCheckOperator):
:param partition_clause: a string SQL statement added to a WHERE clause
to partition data
:param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud.
:param encryption_configuration: [Optional] Custom encryption configuration (e.g., Cloud KMS keys).

    .. code-block:: python

        encryption_configuration = {
            "kmsKeyName": "projects/PROJECT/locations/LOCATION/keyRings/KEY_RING/cryptoKeys/KEY",
        }

:param use_legacy_sql: Whether to use legacy SQL (true)
or standard SQL (false).
:param location: The geographic location of the job. See details at:
Expand Down Expand Up @@ -650,6 +689,7 @@ def __init__(
partition_clause: str | None = None,
database: str | None = None,
accept_none: bool = True,
encryption_configuration: dict | None = None,
gcp_conn_id: str = "google_cloud_default",
use_legacy_sql: bool = True,
location: str | None = None,
Expand All @@ -671,6 +711,7 @@ def __init__(
self.database = database
self.accept_none = accept_none
self.gcp_conn_id = gcp_conn_id
self.encryption_configuration = encryption_configuration
self.use_legacy_sql = use_legacy_sql
self.location = location
self.impersonation_chain = impersonation_chain
Expand All @@ -683,7 +724,7 @@ def _submit_job(
) -> BigQueryJob:
"""Submit a new job and get the job id for polling the status using Triggerer."""
configuration = {"query": {"query": self.sql, "useLegacySql": self.use_legacy_sql}}

self.include_encryption_configuration(configuration, "query")
return hook.insert_job(
configuration=configuration,
project_id=hook.project_id,
Expand Down Expand Up @@ -851,7 +892,7 @@ def execute(self, context=None):
self.log.info("All tests have passed")


class BigQueryGetDataOperator(GoogleCloudBaseOperator):
class BigQueryGetDataOperator(GoogleCloudBaseOperator, _BigQueryOperatorsEncryptionConfigurationMixin):
"""
Fetch data and return it, either from a BigQuery table, or results of a query job.
Expand Down Expand Up @@ -920,6 +961,13 @@ class BigQueryGetDataOperator(GoogleCloudBaseOperator):
from the table. (templated)
:param selected_fields: List of fields to return (comma-separated). If
unspecified, all fields are returned.
:param encryption_configuration: [Optional] Custom encryption configuration (e.g., Cloud KMS keys).

    .. code-block:: python

        encryption_configuration = {
            "kmsKeyName": "projects/PROJECT/locations/LOCATION/keyRings/KEY_RING/cryptoKeys/KEY",
        }

:param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud.
:param location: The location used for the operation.
:param impersonation_chain: Optional service account to impersonate using short-term
Expand Down Expand Up @@ -964,6 +1012,7 @@ def __init__(
selected_fields: str | None = None,
gcp_conn_id: str = "google_cloud_default",
location: str | None = None,
encryption_configuration: dict | None = None,
impersonation_chain: str | Sequence[str] | None = None,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
poll_interval: float = 4.0,
Expand All @@ -983,6 +1032,7 @@ def __init__(
self.gcp_conn_id = gcp_conn_id
self.location = location
self.impersonation_chain = impersonation_chain
self.encryption_configuration = encryption_configuration
self.project_id = project_id
self.deferrable = deferrable
self.poll_interval = poll_interval
Expand All @@ -996,6 +1046,8 @@ def _submit_job(
) -> BigQueryJob:
get_query = self.generate_query(hook=hook)
configuration = {"query": {"query": get_query, "useLegacySql": self.use_legacy_sql}}
self.include_encryption_configuration(configuration, "query")

"""Submit a new job and get the job id for polling the status using Triggerer."""
return hook.insert_job(
configuration=configuration,
Expand Down
Loading

0 comments on commit 7453f3a

Please sign in to comment.