From 7453f3aa4a1e726eaec889d44560188317ccb464 Mon Sep 17 00:00:00 2001 From: Ulada Zakharava Date: Fri, 31 May 2024 09:14:35 +0000 Subject: [PATCH] Add encryption_configuration parameter to BigQuery operators --- .../google/cloud/operators/bigquery.py | 64 ++++++- .../google/cloud/operators/test_bigquery.py | 164 ++++++++++++++++++ 2 files changed, 222 insertions(+), 6 deletions(-) diff --git a/airflow/providers/google/cloud/operators/bigquery.py b/airflow/providers/google/cloud/operators/bigquery.py index 83786ae762d96..6160277f28bf7 100644 --- a/airflow/providers/google/cloud/operators/bigquery.py +++ b/airflow/providers/google/cloud/operators/bigquery.py @@ -149,7 +149,12 @@ class _BigQueryOperatorsEncryptionConfigurationMixin: # annotation of the `self`. Then you can inherit this class in the target operator. # e.g: BigQueryCheckOperator, BigQueryTableCheckOperator def include_encryption_configuration( # type:ignore[misc] - self: BigQueryCheckOperator | BigQueryTableCheckOperator, + self: BigQueryCheckOperator + | BigQueryTableCheckOperator + | BigQueryValueCheckOperator + | BigQueryColumnCheckOperator + | BigQueryGetDataOperator + | BigQueryIntervalCheckOperator, configuration: dict, config_key: str, ) -> None: @@ -326,7 +331,9 @@ def execute_complete(self, context: Context, event: dict[str, Any]) -> None: self.log.info("Success.") -class BigQueryValueCheckOperator(_BigQueryDbHookMixin, SQLValueCheckOperator): +class BigQueryValueCheckOperator( + _BigQueryDbHookMixin, SQLValueCheckOperator, _BigQueryOperatorsEncryptionConfigurationMixin +): """Perform a simple value check using sql code. .. seealso:: @@ -336,6 +343,13 @@ class BigQueryValueCheckOperator(_BigQueryDbHookMixin, SQLValueCheckOperator): :param sql: SQL to execute. :param use_legacy_sql: Whether to use legacy SQL (true) or standard SQL (false). + :param encryption_configuration: [Optional] Custom encryption configuration (e.g., Cloud KMS keys). + + .. code-block:: python + + encryption_configuration = { + "kmsKeyName": "projects/PROJECT/locations/LOCATION/keyRings/KEY_RING/cryptoKeys/KEY", + } :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. :param location: The geographic location of the job. See details at: https://cloud.google.com/bigquery/docs/locations#specifying_your_location @@ -370,6 +384,7 @@ def __init__( sql: str, pass_value: Any, tolerance: Any = None, + encryption_configuration: dict | None = None, gcp_conn_id: str = "google_cloud_default", use_legacy_sql: bool = True, location: str | None = None, @@ -383,6 +398,7 @@ def __init__( self.location = location self.gcp_conn_id = gcp_conn_id self.use_legacy_sql = use_legacy_sql + self.encryption_configuration = encryption_configuration self.impersonation_chain = impersonation_chain self.labels = labels self.deferrable = deferrable @@ -401,6 +417,8 @@ def _submit_job( }, } + self.include_encryption_configuration(configuration, "query") + return hook.insert_job( configuration=configuration, project_id=hook.project_id, @@ -460,7 +478,9 @@ def execute_complete(self, context: Context, event: dict[str, Any]) -> None: ) -class BigQueryIntervalCheckOperator(_BigQueryDbHookMixin, SQLIntervalCheckOperator): +class BigQueryIntervalCheckOperator( + _BigQueryDbHookMixin, SQLIntervalCheckOperator, _BigQueryOperatorsEncryptionConfigurationMixin +): """ Check that the values of metrics given as SQL expressions are within a tolerance of the older ones. @@ -481,6 +501,13 @@ class BigQueryIntervalCheckOperator(_BigQueryDbHookMixin, SQLIntervalCheckOperat between the current day, and the prior days_back. :param use_legacy_sql: Whether to use legacy SQL (true) or standard SQL (false). + :param encryption_configuration: [Optional] Custom encryption configuration (e.g., Cloud KMS keys). + + .. code-block:: python + + encryption_configuration = { + "kmsKeyName": "projects/PROJECT/locations/LOCATION/keyRings/KEY_RING/cryptoKeys/KEY", + } :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. :param location: The geographic location of the job. See details at: https://cloud.google.com/bigquery/docs/locations#specifying_your_location @@ -520,6 +547,7 @@ def __init__( gcp_conn_id: str = "google_cloud_default", use_legacy_sql: bool = True, location: str | None = None, + encryption_configuration: dict | None = None, impersonation_chain: str | Sequence[str] | None = None, labels: dict | None = None, deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), @@ -538,6 +566,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.use_legacy_sql = use_legacy_sql self.location = location + self.encryption_configuration = encryption_configuration self.impersonation_chain = impersonation_chain self.labels = labels self.project_id = project_id @@ -552,6 +581,7 @@ def _submit_job( ) -> BigQueryJob: """Submit a new job and get the job id for polling the status using Triggerer.""" configuration = {"query": {"query": sql, "useLegacySql": self.use_legacy_sql}} + self.include_encryption_configuration(configuration, "query") return hook.insert_job( configuration=configuration, project_id=self.project_id or hook.project_id, @@ -608,7 +638,9 @@ def execute_complete(self, context: Context, event: dict[str, Any]) -> None: ) -class BigQueryColumnCheckOperator(_BigQueryDbHookMixin, SQLColumnCheckOperator): +class BigQueryColumnCheckOperator( + _BigQueryDbHookMixin, SQLColumnCheckOperator, _BigQueryOperatorsEncryptionConfigurationMixin +): """ Subclasses the SQLColumnCheckOperator in order to provide a job id for OpenLineage to parse. @@ -623,6 +655,13 @@ class BigQueryColumnCheckOperator(_BigQueryDbHookMixin, SQLColumnCheckOperator): :param partition_clause: a string SQL statement added to a WHERE clause to partition data :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. + :param encryption_configuration: [Optional] Custom encryption configuration (e.g., Cloud KMS keys). + + .. code-block:: python + + encryption_configuration = { + "kmsKeyName": "projects/PROJECT/locations/LOCATION/keyRings/KEY_RING/cryptoKeys/KEY", + } :param use_legacy_sql: Whether to use legacy SQL (true) or standard SQL (false). :param location: The geographic location of the job. See details at: @@ -650,6 +689,7 @@ def __init__( partition_clause: str | None = None, database: str | None = None, accept_none: bool = True, + encryption_configuration: dict | None = None, gcp_conn_id: str = "google_cloud_default", use_legacy_sql: bool = True, location: str | None = None, @@ -671,6 +711,7 @@ def __init__( self.database = database self.accept_none = accept_none self.gcp_conn_id = gcp_conn_id + self.encryption_configuration = encryption_configuration self.use_legacy_sql = use_legacy_sql self.location = location self.impersonation_chain = impersonation_chain @@ -683,7 +724,7 @@ def _submit_job( ) -> BigQueryJob: """Submit a new job and get the job id for polling the status using Trigger.""" configuration = {"query": {"query": self.sql, "useLegacySql": self.use_legacy_sql}} - + self.include_encryption_configuration(configuration, "query") return hook.insert_job( configuration=configuration, project_id=hook.project_id, @@ -851,7 +892,7 @@ def execute(self, context=None): self.log.info("All tests have passed") -class BigQueryGetDataOperator(GoogleCloudBaseOperator): +class BigQueryGetDataOperator(GoogleCloudBaseOperator, _BigQueryOperatorsEncryptionConfigurationMixin): """ Fetch data and return it, either from a BigQuery table, or results of a query job. @@ -920,6 +961,13 @@ class BigQueryGetDataOperator(GoogleCloudBaseOperator): from the table. (templated) :param selected_fields: List of fields to return (comma-separated). If unspecified, all fields are returned. + :param encryption_configuration: [Optional] Custom encryption configuration (e.g., Cloud KMS keys). + + .. code-block:: python + + encryption_configuration = { + "kmsKeyName": "projects/PROJECT/locations/LOCATION/keyRings/KEY_RING/cryptoKeys/KEY", + } :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud. :param location: The location used for the operation. :param impersonation_chain: Optional service account to impersonate using short-term @@ -964,6 +1012,7 @@ def __init__( selected_fields: str | None = None, gcp_conn_id: str = "google_cloud_default", location: str | None = None, + encryption_configuration: dict | None = None, impersonation_chain: str | Sequence[str] | None = None, deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), poll_interval: float = 4.0, @@ -983,6 +1032,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.location = location self.impersonation_chain = impersonation_chain + self.encryption_configuration = encryption_configuration self.project_id = project_id self.deferrable = deferrable self.poll_interval = poll_interval @@ -996,6 +1046,8 @@ def _submit_job( ) -> BigQueryJob: get_query = self.generate_query(hook=hook) configuration = {"query": {"query": get_query, "useLegacySql": self.use_legacy_sql}} + self.include_encryption_configuration(configuration, "query") + """Submit a new job and get the job id for polling the status using Triggerer.""" return hook.insert_job( configuration=configuration, diff --git a/tests/providers/google/cloud/operators/test_bigquery.py b/tests/providers/google/cloud/operators/test_bigquery.py index 16c9cbdb820ba..3fa34467610e4 100644 --- a/tests/providers/google/cloud/operators/test_bigquery.py +++ b/tests/providers/google/cloud/operators/test_bigquery.py @@ -1062,6 +1062,49 @@ def test_bigquery_get_data_op_execute_complete_with_records(self, as_dict): operator.execute_complete(context=None, event={"status": "success", "records": [20]}) mock_log_info.assert_called_with("Total extracted rows: %s", 1) + @pytest.mark.parametrize("as_dict", [True, False]) + @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook") + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryJob") + def test_encryption_configuration(self, mock_job, mock_hook, as_dict): + encryption_configuration = { + "kmsKeyName": "projects/PROJECT/locations/LOCATION/keyRings/KEY_RING/cryptoKeys/KEY", + } + + mock_hook.return_value.insert_job.return_value = mock_job + mock_hook.return_value.project_id = TEST_GCP_PROJECT_ID + + max_results = 1 + selected_fields = "DATE" + operator = BigQueryGetDataOperator( + job_project_id=TEST_GCP_PROJECT_ID, + gcp_conn_id=GCP_CONN_ID, + task_id=TASK_ID, + job_id="", + max_results=max_results, + dataset_id=TEST_DATASET, + table_id=TEST_TABLE_ID, + selected_fields=selected_fields, + location=TEST_DATASET_LOCATION, + as_dict=as_dict, + encryption_configuration=encryption_configuration, + deferrable=True, + ) + with pytest.raises(TaskDeferred): + operator.execute(MagicMock()) + mock_hook.return_value.insert_job.assert_called_with( + configuration={ + "query": { + "query": f"""select DATE from `{TEST_GCP_PROJECT_ID}.{TEST_DATASET}.{TEST_TABLE_ID}` limit 1""", + "useLegacySql": True, + "destinationEncryptionConfiguration": encryption_configuration, + } + }, + project_id=TEST_GCP_PROJECT_ID, + location=TEST_DATASET_LOCATION, + job_id="", + nowait=True, + ) + class TestBigQueryTableDeleteOperator: @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook") @@ -2137,6 +2180,40 @@ def test_bigquery_interval_check_operator_without_project_id( nowait=True, ) + @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook") + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryJob") + def test_encryption_configuration_deferrable_mode(self, mock_job, mock_hook): + encryption_configuration = { + "kmsKeyName": "projects/PROJECT/locations/LOCATION/keyRings/KEY_RING/cryptoKeys/KEY", + } + + mock_hook.return_value.insert_job.return_value = mock_job + mock_hook.return_value.project_id = TEST_GCP_PROJECT_ID + + operator = BigQueryIntervalCheckOperator( + task_id="TASK_ID", + encryption_configuration=encryption_configuration, + location=TEST_DATASET_LOCATION, + metrics_thresholds={"COUNT(*)": 1.5}, + table=TEST_TABLE_ID, + deferrable=True, + ) + with pytest.raises(TaskDeferred): + operator.execute(MagicMock()) + mock_hook.return_value.insert_job.assert_called_with( + configuration={ + "query": { + "query": """SELECT COUNT(*) FROM test-table-id WHERE ds='{{ macros.ds_add(ds, -7) }}'""", + "useLegacySql": True, + "destinationEncryptionConfiguration": encryption_configuration, + } + }, + project_id=TEST_GCP_PROJECT_ID, + location=TEST_DATASET_LOCATION, + job_id="", + nowait=True, + ) + class TestBigQueryCheckOperator: @pytest.mark.db_test @@ -2425,6 +2502,46 @@ def test_bigquery_value_check_operator_execute_complete_failure(self): context=None, event={"status": "error", "message": "test failure message"} ) + @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook") + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryJob") + def test_encryption_configuration_deferrable_mode(self, mock_job, mock_hook): + encryption_configuration = { + "kmsKeyName": "projects/PROJECT/locations/LOCATION/keyRings/KEY_RING/cryptoKeys/KEY", + } + + mock_job.result.return_value.to_dataframe.return_value = pd.DataFrame( + { + "check_name": ["row_count_check"], + "check_result": [1], + } + ) + mock_hook.return_value.insert_job.return_value = mock_job + mock_hook.return_value.project_id = TEST_GCP_PROJECT_ID + + operator = BigQueryValueCheckOperator( + task_id="TASK_ID", + encryption_configuration=encryption_configuration, + location=TEST_DATASET_LOCATION, + pass_value=2, + sql=f"SELECT COUNT(*) FROM {TEST_DATASET}.{TEST_TABLE_ID}", + deferrable=True, + ) + with pytest.raises(TaskDeferred): + operator.execute(MagicMock()) + mock_hook.return_value.insert_job.assert_called_with( + configuration={ + "query": { + "query": f"""SELECT COUNT(*) FROM {TEST_DATASET}.{TEST_TABLE_ID}""", + "useLegacySql": True, + "destinationEncryptionConfiguration": encryption_configuration, + } + }, + project_id=TEST_GCP_PROJECT_ID, + location=TEST_DATASET_LOCATION, + job_id="", + nowait=True, + ) + @pytest.mark.db_test class TestBigQueryColumnCheckOperator: @@ -2495,6 +2612,53 @@ def test_bigquery_column_check_operator_fails( with pytest.raises(AirflowException): ti.task.execute(MagicMock()) + @pytest.mark.parametrize( + "check_type, check_value, check_result", + [ + ("equal_to", 0, 0), + ("greater_than", 0, 1), + ("less_than", 0, -1), + ], + ) + @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook") + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryJob") + def test_encryption_configuration(self, mock_job, mock_hook, check_type, check_value, check_result): + encryption_configuration = { + "kmsKeyName": "projects/PROJECT/locations/LOCATION/keyRings/KEY_RING/cryptoKeys/KEY", + } + + mock_job.result.return_value.to_dataframe.return_value = pd.DataFrame( + {"col_name": ["col1"], "check_type": ["min"], "check_result": [check_result]} + ) + mock_hook.return_value.insert_job.return_value = mock_job + mock_hook.return_value.project_id = TEST_GCP_PROJECT_ID + + operator = BigQueryColumnCheckOperator( + task_id="TASK_ID", + encryption_configuration=encryption_configuration, + table=f"{TEST_DATASET}.{TEST_TABLE_ID}", + column_mapping={"col1": {"min": {check_type: check_value}}}, + location=TEST_DATASET_LOCATION, + ) + + operator.execute(MagicMock()) + mock_hook.return_value.insert_job.assert_called_with( + configuration={ + "query": { + "query": f"""SELECT col_name, check_type, check_result FROM ( + SELECT 'col1' AS col_name, 'min' AS check_type, col1_min AS check_result + FROM (SELECT MIN(col1) AS col1_min FROM {TEST_DATASET}.{TEST_TABLE_ID} ) AS sq + ) AS check_columns""", + "useLegacySql": True, + "destinationEncryptionConfiguration": encryption_configuration, + } + }, + project_id=TEST_GCP_PROJECT_ID, + location=TEST_DATASET_LOCATION, + job_id="", + nowait=False, + ) + class TestBigQueryTableCheckOperator: @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")