From 1c87e215f8787ecd6ecf73ac7086a8516d99f4cd Mon Sep 17 00:00:00 2001
From: Ashok Singamaneni
Date: Thu, 7 Dec 2023 12:29:08 -0800
Subject: [PATCH] [Feature] updating code to work for retries if the save table fails (#61)

* updating code to work for retries if the save table fails. Alter table is not done for stats table

* Removing retries as it's not needed
---
 spark_expectations/core/exceptions.py    |  6 +--
 spark_expectations/sinks/utils/writer.py | 30 +++++++++++++--
 tests/sinks/utils/test_writer.py         | 49 ++++++++++++++++--------
 3 files changed, 62 insertions(+), 23 deletions(-)

diff --git a/spark_expectations/core/exceptions.py b/spark_expectations/core/exceptions.py
index 8f94155..1edb6c3 100644
--- a/spark_expectations/core/exceptions.py
+++ b/spark_expectations/core/exceptions.py
@@ -43,7 +43,7 @@ class SparkExpectationsMiscException(Exception):
 
 class SparkExpectationsSlackNotificationException(Exception):
     """
-    Throw this exception when spark expectations encounters miscellaneous exceptions
+    Throw this exception when spark expectations encounters exceptions while sending Slack notifications
     """
 
     pass
@@ -51,7 +51,7 @@ class SparkExpectationsSlackNotificationException(Exception):
 
 class SparkExpectationsTeamsNotificationException(Exception):
     """
-    Throw this exception when spark expectations encounters miscellaneous exceptions
+    Throw this exception when spark expectations encounters exceptions while sending Teams notifications
     """
 
     pass
@@ -59,7 +59,7 @@ class SparkExpectationsTeamsNotificationException(Exception):
 
 class SparkExpectationsEmailException(Exception):
     """
-    Throw this exception when spark expectations encounters miscellaneous exceptions
+    Throw this exception when spark expectations encounters exceptions while sending email notifications
     """
 
     pass
diff --git a/spark_expectations/sinks/utils/writer.py b/spark_expectations/sinks/utils/writer.py
index 5c0fc7e..3cc2ad3 100644
--- a/spark_expectations/sinks/utils/writer.py
+++ b/spark_expectations/sinks/utils/writer.py
@@ -83,14 +83,36 @@ def save_df_as_table(
         if config["options"] is not None and config["options"] != {}:
             _df_writer = _df_writer.options(**config["options"])
 
+        _log.info("Writing records to table: %s", table_name)
+
         if config["format"] == "bigquery":
             _df_writer.option("table", table_name).save()
         else:
             _df_writer.saveAsTable(name=table_name)
-            self.spark.sql(
-                f"ALTER TABLE {table_name} SET TBLPROPERTIES ('product_id' = '{self._context.product_id}')"
-            )
-        _log.info("finished writing records to table: %s,", table_name)
+            _log.info("finished writing records to table: %s,", table_name)
+            if not stats_table:
+                # Fetch table properties
+                table_properties = self.spark.sql(
+                    f"SHOW TBLPROPERTIES {table_name}"
+                ).collect()
+                table_properties_dict = {
+                    row["key"]: row["value"] for row in table_properties
+                }
+
+                # Set product_id in table properties
+                if (
+                    table_properties_dict.get("product_id") is None
+                    or table_properties_dict.get("product_id")
+                    != self._context.product_id
+                ):
+                    _log.info(
+                        "product_id is not set for table %s in tableproperties, setting it now",
+                        table_name,
+                    )
+                    self.spark.sql(
+                        f"ALTER TABLE {table_name} SET TBLPROPERTIES ('product_id' = "
+                        f"'{self._context.product_id}')"
+                    )
 
     except Exception as e:
         raise SparkExpectationsUserInputOrConfigInvalidException(
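The writer.py hunk above turns the ALTER TABLE from an unconditional statement into a guarded one: the writer now reads the current TBLPROPERTIES and only stamps product_id when it is missing or stale, and it skips the step entirely for the stats table. Below is a minimal standalone sketch of that guarded flow, assuming an active SparkSession, an existing table, and a catalog that supports SHOW TBLPROPERTIES; the table name and product id are placeholder values:

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()

    table_name = "employee_table"  # placeholder target table
    product_id = "product1"        # placeholder; supplied by the context in the real writer

    # SHOW TBLPROPERTIES returns rows with "key" and "value" columns
    props = {
        row["key"]: row["value"]
        for row in spark.sql(f"SHOW TBLPROPERTIES {table_name}").collect()
    }

    # Run the DDL only when the property is absent or different, so a
    # retried write does not re-issue ALTER TABLE on every attempt
    if props.get("product_id") != product_id:
        spark.sql(
            f"ALTER TABLE {table_name} SET TBLPROPERTIES ('product_id' = '{product_id}')"
        )

The single comparison props.get("product_id") != product_id collapses the diff's two-part check (property is None, or property differs) into one condition; both branches trigger the same ALTER TABLE.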
diff --git a/tests/sinks/utils/test_writer.py b/tests/sinks/utils/test_writer.py
index f245212..7429804 100644
--- a/tests/sinks/utils/test_writer.py
+++ b/tests/sinks/utils/test_writer.py
@@ -2,7 +2,6 @@
 import unittest.mock
 from unittest.mock import patch, Mock
 
-import pyspark.sql
 import pytest
 from pyspark.sql.functions import col
 from pyspark.sql.functions import lit, to_timestamp
@@ -57,7 +56,7 @@ def fixture_writer():
     setattr(mock_context, "get_run_id_name", "meta_dq_run_id")
     setattr(mock_context, "get_run_date_name", "meta_dq_run_date")
     mock_context.spark = spark
-    mock_context.product_id='product1'
+    mock_context.product_id = 'product1'
 
     # Create an instance of the class and set the product_id
     return SparkExpectationsWriter(mock_context)
@@ -160,9 +159,11 @@ def fixture_expected_dq_dataset():
 @pytest.mark.parametrize('table_name, options, expected_count',
                          [('employee_table',
                            {'mode': 'overwrite', 'partitionBy': ['department'], "format": "parquet",
-                            'bucketBy': {'numBuckets':2,'colName':'business_unit'}, 'sortBy': ["eeid"], 'options': {"overwriteSchema": "true", "mergeSchema": "true"}}, 1000),
+                            'bucketBy': {'numBuckets': 2, 'colName': 'business_unit'}, 'sortBy': ["eeid"],
+                            'options': {"overwriteSchema": "true", "mergeSchema": "true"}}, 1000),
                           ('employee_table',
-                           {'mode': 'append', "format": "delta", 'partitionBy': [], 'bucketBy': {}, 'sortBy': [], 'options': {"mergeSchema": "true"}},
+                           {'mode': 'append', "format": "delta", 'partitionBy': [], 'bucketBy': {}, 'sortBy': [],
+                            'options': {"mergeSchema": "true"}},
                            1000)
                          ])
 def test_save_df_as_table(table_name,
@@ -175,17 +176,31 @@ def test_save_df_as_table(table_name,
 
     assert expected_count == spark.sql(f"select * from {table_name}").count()
 
-    # Assert
-    # _spark_set.assert_called_with('spark.sql.session.timeZone', 'Etc/UTC')
+    # Fetch table properties
+    table_properties = spark.sql(f"SHOW TBLPROPERTIES {table_name}").collect()
+    table_properties_dict = {row["key"]: row["value"] for row in table_properties}
+
+    # Check 'product_id' property
+    assert table_properties_dict.get("product_id") == _fixture_writer._context.product_id
+
+    spark.sql(f"drop table if exists {table_name}")
+    spark.sql(f"drop table if exists {table_name}_stats")
+    spark.sql(f"drop table if exists {table_name}_error")
+
+    _fixture_writer.save_df_as_table(_fixture_employee, table_name, options, True)
+    # Fetch table properties
+    table_properties = spark.sql(f"SHOW TBLPROPERTIES {table_name}").collect()
+    table_properties_dict = {row["key"]: row["value"] for row in table_properties}
+    assert table_properties_dict.get("product_id") is None
 
 
 @patch('pyspark.sql.DataFrameWriter.save', autospec=True, spec_set=True)
-def test_save_df_to_table_bq(save,  _fixture_writer, _fixture_employee, _fixture_create_employee_table):
+def test_save_df_to_table_bq(save, _fixture_writer, _fixture_employee, _fixture_create_employee_table):
     _fixture_writer.save_df_as_table(_fixture_employee, 'employee_table',
                                      {'mode': 'overwrite', 'format': 'bigquery',
-                                      'partitionBy':[], 'bucketBy': {}, 'sortBy':[], 'options':{}})
+                                      'partitionBy': [], 'bucketBy': {},
+                                      'sortBy': [], 'options': {}})
 
     save.assert_called_once_with(unittest.mock.ANY)
 
-
 @pytest.mark.parametrize('table_name, options',
                          [('employee_table',
                            {'mode': 'overwrite', 'partitionBy': ['department'],
@@ -439,7 +454,8 @@ def test_write_df_to_table(save_df_as_table,
                                "output_percentage": 0.0,
                                "success_percentage": 0.0,
                                "error_percentage": 100.0,
-                               }, {'mode': 'append', "format": "bigquery", 'partitionBy': [], 'bucketBy': {}, 'sortBy': [], 'options': {"mergeSchema": "true"}})
+                               }, {'mode': 'append', "format": "bigquery", 'partitionBy': [], 'bucketBy': {}, 'sortBy': [],
+                                   'options': {"mergeSchema": "true"}})
                           ])
 def test_write_error_stats(input_record,
                            expected_result,
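A note on the bigquery test above: save.assert_called_once_with(unittest.mock.ANY) passes because patching with autospec=True replaces the unbound method, so the mock records self as the first positional argument of each call; ANY then matches the DataFrameWriter instance. A small self-contained illustration of the same pattern, with a Sink class invented for the demo:

    import unittest.mock
    from unittest.mock import patch

    class Sink:
        def save(self):
            raise RuntimeError("would hit a real sink")

    # autospec=True patches the unbound method, so each call records `self`
    with patch.object(Sink, "save", autospec=True, spec_set=True) as save:
        Sink().save()
        save.assert_called_once_with(unittest.mock.ANY)  # ANY matches the instance

This is documented autospec behavior for methods patched on a class, and it is why the production call _df_writer.option("table", table_name).save() shows up in the mock as a single call whose only recorded argument is the writer instance.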
@@ -502,8 +518,10 @@ def test_write_error_stats(input_record,
     setattr(_mock_context, 'get_dq_stats_table_name', 'test_dq_stats_table')
 
     if writer_config is None:
-        setattr(_mock_context, "_stats_table_writer_config", WrappedDataFrameWriter().mode("overwrite").format("delta").build())
-        setattr(_mock_context, 'get_stats_table_writer_config', WrappedDataFrameWriter().mode("overwrite").format("delta").build())
+        setattr(_mock_context, "_stats_table_writer_config",
+                WrappedDataFrameWriter().mode("overwrite").format("delta").build())
+        setattr(_mock_context, 'get_stats_table_writer_config',
+                WrappedDataFrameWriter().mode("overwrite").format("delta").build())
     else:
         setattr(_mock_context, "_stats_table_writer_config", writer_config)
         setattr(_mock_context, 'get_stats_table_writer_config', writer_config)
@@ -554,9 +572,6 @@ def test_write_error_stats(input_record,
                                        "to_json(struct(*)) AS value").collect()
 
 
-
-
-
 @pytest.mark.parametrize('table_name, rule_type',
                          [('test_error_table', 'row_dq'
@@ -608,7 +623,8 @@ def test_write_error_records_final_dependent(save_df_as_table,
         .withColumn("meta_dq_run_date", lit("2022-12-27 10:39:44")) \
         .orderBy("id").collect()
     assert save_df_args[0][2] == table_name
-    save_df_as_table.assert_called_once_with(_fixture_writer, save_df_args[0][1], table_name, _fixture_writer._context.get_target_and_error_table_writer_config)
+    save_df_as_table.assert_called_once_with(_fixture_writer, save_df_args[0][1], table_name,
+                                             _fixture_writer._context.get_target_and_error_table_writer_config)
 
 
 @pytest.mark.parametrize("test_data, expected_result", [
@@ -647,6 +663,7 @@ def test_generate_summarised_row_dq_res(test_data, expected_result):
     result = context.get_summarised_row_dq_res
     assert result == expected_result
 
+
 @pytest.mark.parametrize('dq_rules, summarised_row_dq, expected_result', [
     (
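For completeness, the WrappedDataFrameWriter().mode("overwrite").format("delta").build() chain used in the stats-table setup above produces the same flat config dict that the parametrized cases spell out by hand (mode, format, partitionBy, bucketBy, sortBy, options). A tiny stand-in builder illustrating that shape; this is a sketch inferred from the tests, not the library's actual implementation:

    class DemoWriterBuilder:
        """Stand-in for WrappedDataFrameWriter; config shape inferred from the tests."""

        def __init__(self):
            self._config = {"mode": None, "format": None, "partitionBy": [],
                            "bucketBy": {}, "sortBy": [], "options": {}}

        def mode(self, mode):
            self._config["mode"] = mode
            return self  # chainable, like the builder in the diff

        def format(self, fmt):
            self._config["format"] = fmt
            return self

        def build(self):
            return dict(self._config)

    # Yields a dict like the ones passed as `options` to save_df_as_table
    print(DemoWriterBuilder().mode("overwrite").format("delta").build())

Only mode and format are exercised here because that is all the stats-table setup needs; the remaining keys keep the built dict uniform with the hand-written parametrize cases.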