From ca1ceb5c6a1010b561b211d9b7fa3760f9d08bc6 Mon Sep 17 00:00:00 2001 From: Ashok Singamaneni Date: Thu, 28 Sep 2023 11:31:03 +0530 Subject: [PATCH] Feature remove mandatory args (#40) * making changes and refactoring the code. * Adding Wrapped DataFrameWriter and unitests for the same * ignoring agg_dq and query_dq * Rearranging lot of code, removed delta. Enabled custom writers in different formats * fixing tests for the sinks and updating README * fixing kafka configurations * Feature add reader * changing rules table along with few other changes * reader and writer modifications --------- Co-authored-by: phanikumarvemuri * Updating formatting, removing delta as dependency * Adding examples for expectations * adding documentation Co-authored-by: phanikumarvemuri * Updating tests * Updating documentation * Updating WrappedDataframeWriter --------- Co-authored-by: phanikumarvemuri --- CONTRIBUTORS.md | 3 +- README.md | 78 +- docs/api/delta_sink_plugin.md | 11 - .../{sample_dq.md => sample_dq_bigquery.md} | 6 +- docs/api/sample_dq_delta.md | 8 + docs/api/sample_dq_iceberg.md | 8 + docs/bigquery.md | 92 + .../adoption_versions_comparsion.md | 21 +- docs/configurations/rules.md | 54 +- docs/delta.md | 82 + docs/examples.md | 462 +--- docs/iceberg.md | 88 + docs/index.md | 14 +- mike-mkdocs_vbm2uyw.yml | 146 -- mkdocs.yml | 7 +- poetry.lock | 118 +- prospector.yaml | 2 +- pyproject.toml | 2 - spark_expectations/config/user_config.py | 9 - spark_expectations/core/__init__.py | 7 +- spark_expectations/core/context.py | 58 +- spark_expectations/core/expectations.py | 331 ++- spark_expectations/examples/base_setup.py | 243 +- .../docker_kafka_start_script.sh | 2 +- .../docker_kafka_stop_script.sh | 2 +- .../examples/sample_dq_bigquery.py | 120 + .../{sample_dq.py => sample_dq_delta.py} | 57 +- .../examples/sample_dq_iceberg.py | 103 + .../push/spark_expectations_notify.py | 19 +- spark_expectations/secrets/__init__.py | 7 +- spark_expectations/sinks/__init__.py | 6 - .../sinks/plugins/base_writer.py | 2 +- .../sinks/plugins/delta_writer.py | 42 - .../sinks/plugins/kafka_writer.py | 6 +- .../sinks/utils/collect_statistics.py | 1 - spark_expectations/sinks/utils/writer.py | 252 +- spark_expectations/utils/actions.py | 37 +- spark_expectations/utils/reader.py | 211 +- spark_expectations/utils/regulate_flow.py | 17 +- tests/config/test_user_config.py | 14 - tests/core/test_context.py | 520 ++-- tests/core/test_expectations.py | 2216 ++++++++--------- .../push/test_spark_expectations_notify.py | 39 +- tests/sinks/plugins/test_delta_writer.py | 90 - ...est_nsp_writer.py => test_kafka_writer.py} | 4 +- tests/sinks/test__init__.py | 13 +- tests/sinks/utils/test_collect_statistics.py | 38 +- tests/sinks/utils/test_writer.py | 229 +- tests/utils/test_actions.py | 40 +- tests/utils/test_reader.py | 356 +-- tests/utils/test_regulate_flow.py | 29 +- 51 files changed, 2984 insertions(+), 3338 deletions(-) delete mode 100644 docs/api/delta_sink_plugin.md rename docs/api/{sample_dq.md => sample_dq_bigquery.md} (59%) create mode 100644 docs/api/sample_dq_delta.md create mode 100644 docs/api/sample_dq_iceberg.md create mode 100644 docs/bigquery.md create mode 100644 docs/delta.md create mode 100644 docs/iceberg.md delete mode 100644 mike-mkdocs_vbm2uyw.yml create mode 100644 spark_expectations/examples/sample_dq_bigquery.py rename spark_expectations/examples/{sample_dq.py => sample_dq_delta.py} (71%) create mode 100644 spark_expectations/examples/sample_dq_iceberg.py delete mode 100644 
spark_expectations/sinks/plugins/delta_writer.py delete mode 100644 tests/sinks/plugins/test_delta_writer.py rename tests/sinks/plugins/{test_nsp_writer.py => test_kafka_writer.py} (97%) diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index b473972f..1d4cda70 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -6,6 +6,8 @@ Thanks to the contributors who helped on this project apart from the authors * [Teja Dogiparthi](https://github.com/Tejadogiparthi) * [Phani Kumar Vemuri](https://www.linkedin.com/in/vemuriphani/) +* [Sarath Chandra Bandaru](https://www.linkedin.com/in/sarath-chandra-bandaru/) +* [Holden Karau](https://www.linkedin.com/in/holdenkarau/) # Honorary Mentions Thanks to the team below for invaluable insights and support throughout the initial release of this project @@ -15,4 +17,3 @@ Thanks to the team below for invaluable insights and support throughout the init * [Aditya Chaturvedi](https://www.linkedin.com/in/chaturvediaditya/) * [Scott Haines](https://www.linkedin.com/in/scotthaines/) * [Arijit Banerjee](https://www.linkedin.com/in/massborn/) -* [Sarath Chandra Bandaru](https://www.linkedin.com/in/sarath-chandra-bandaru/) diff --git a/README.md b/README.md index 8111cf4a..49e39699 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,8 @@ [![Checked with mypy](http://www.mypy-lang.org/static/mypy_badge.svg)](http://mypy-lang.org/) [![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) ![PYPI version](https://img.shields.io/pypi/v/spark-expectations.svg) - +![PYPI - Downloads](https://static.pepy.tech/badge/spark-expectations) +![PYPI - Python Version](https://img.shields.io/pypi/pyversions/spark-expectations.svg)

Spark Expectations is a specialized tool designed with the primary goal of maintaining data integrity within your processing pipeline. @@ -70,7 +71,7 @@ is provided in the appropriate fields. ```python from spark_expectations.config.user_config import * -se_global_spark_Conf = { +se_user_conf = { se_notifications_enable_email: False, se_notifications_email_smtp_host: "mailhost.nike.com", se_notifications_email_smtp_port: 25, @@ -91,66 +92,49 @@ se_global_spark_Conf = { For all the below examples the below import and SparkExpectations class instantiation is mandatory -```python -from spark_expectations.core.expectations import SparkExpectations +1. Instantiate `SparkExpectations` class which has all the required functions for running data quality rules +```python +from spark_expectations.core.expectations import SparkExpectations, WrappedDataFrameWriter +from pyspark.sql import SparkSession +spark: SparkSession = SparkSession.builder.getOrCreate() +writer = WrappedDataFrameWriter().mode("append").format("delta") +# writer = WrappedDataFrameWriter().mode("append").format("iceberg") # product_id should match with the "product_id" in the rules table -se: SparkExpectations = SparkExpectations(product_id="your-products-id") +se: SparkExpectations = SparkExpectations( + product_id="your_product", + rules_df=spark.table("dq_spark_local.dq_rules"), + stats_table="dq_spark_local.dq_stats", + stats_table_writer=writer, + target_and_error_table_writer=writer, + debugger=False, + # stats_streaming_options={user_config.se_enable_streaming: False}, +) ``` -1. Instantiate `SparkExpectations` class which has all the required functions for running data quality rules - +2. Decorate the function with `@se.with_expectations` decorator ```python -from spark_expectations.config.user_config import * - - -@se.with_expectations( - se.reader.get_rules_from_table( - product_rules_table="pilot_nonpub.dq.dq_rules", - table_name="pilot_nonpub.dq_employee.employee", - dq_stats_table_name="pilot_nonpub.dq.dq_stats" - ), - write_to_table=True, - write_to_temp_table=True, - row_dq=True, - agg_dq={ - se_agg_dq: True, - se_source_agg_dq: True, - se_final_agg_dq: True, - }, - query_dq={ - se_query_dq: True, - se_source_query_dq: True, - se_final_query_dq: True, - se_target_table_view: "order", - }, - spark_conf=se_global_spark_Conf, +from spark_expectations.config.user_config import * +from pyspark.sql import DataFrame +import os + +@se.with_expectations( + target_table="dq_spark_local.customer_order", + write_to_table=True, + user_conf=se_user_conf, + target_table_view="order", ) def build_new() -> DataFrame: + # Return the dataframe on which Spark-Expectations needs to be run _df_order: DataFrame = ( spark.read.option("header", "true") .option("inferSchema", "true") .csv(os.path.join(os.path.dirname(__file__), "resources/order.csv")) ) - _df_order.createOrReplaceTempView("order") - - _df_product: DataFrame = ( - spark.read.option("header", "true") - .option("inferSchema", "true") - .csv(os.path.join(os.path.dirname(__file__), "resources/product.csv")) - ) - _df_product.createOrReplaceTempView("product") - - _df_customer: DataFrame = ( - spark.read.option("header", "true") - .option("inferSchema", "true") - .csv(os.path.join(os.path.dirname(__file__), "resources/customer.csv")) - ) - - _df_customer.createOrReplaceTempView("customer") + _df_order.createOrReplaceTempView("order") return _df_order ``` diff --git a/docs/api/delta_sink_plugin.md b/docs/api/delta_sink_plugin.md deleted file mode 100644 index 91d8d265..00000000 --- 
a/docs/api/delta_sink_plugin.md +++ /dev/null
@@ -1,11 +0,0 @@
----
-search:
-  exclude: true
----
-
-::: spark_expectations.sinks.plugins.delta_writer
-    handler: python
-    options:
-      filters:
-        - "!^_[^_]"
-        - "!^__[^__]"
\ No newline at end of file
diff --git a/docs/api/sample_dq.md b/docs/api/sample_dq_bigquery.md
similarity index 59%
rename from docs/api/sample_dq.md
rename to docs/api/sample_dq_bigquery.md
index 238ca071..7d0ccb49 100644
--- a/docs/api/sample_dq.md
+++ b/docs/api/sample_dq_bigquery.md
@@ -1,9 +1,5 @@
----
-search:
-  exclude: true
----
-::: spark_expectations.examples.sample_dq
+::: spark_expectations.examples.sample_dq_bigquery
     handler: python
     options:
       filters:
diff --git a/docs/api/sample_dq_delta.md b/docs/api/sample_dq_delta.md
new file mode 100644
index 00000000..db6af0f3
--- /dev/null
+++ b/docs/api/sample_dq_delta.md
@@ -0,0 +1,8 @@
+
+::: spark_expectations.examples.sample_dq_delta
+    handler: python
+    options:
+      filters:
+        - "!^_[^_]"
+        - "!^__[^__]"
+
\ No newline at end of file
diff --git a/docs/api/sample_dq_iceberg.md b/docs/api/sample_dq_iceberg.md
new file mode 100644
index 00000000..29582af7
--- /dev/null
+++ b/docs/api/sample_dq_iceberg.md
@@ -0,0 +1,8 @@
+
+::: spark_expectations.examples.sample_dq_iceberg
+    handler: python
+    options:
+      filters:
+        - "!^_[^_]"
+        - "!^__[^__]"
+
\ No newline at end of file
diff --git a/docs/bigquery.md b/docs/bigquery.md
new file mode 100644
index 00000000..0299bece
--- /dev/null
+++ b/docs/bigquery.md
@@ -0,0 +1,92 @@
+### Example - Write to BigQuery
+
+Set up a SparkSession for BigQuery to test in your local environment. Configure accordingly for higher environments.
+Refer to the examples in [base_setup.py](../spark_expectations/examples/base_setup.py) and
+[sample_dq_bigquery.py](../spark_expectations/examples/sample_dq_bigquery.py)
+
+```python title="spark_session"
+from pyspark.sql import SparkSession
+
+builder = (
+    SparkSession.builder.config(
+        "spark.jars.packages",
+        "com.google.cloud.spark:spark-bigquery-with-dependencies_2.12:0.30.0",
+    )
+)
+spark = builder.getOrCreate()
+
+spark._jsc.hadoopConfiguration().set(
+    "fs.gs.impl", "com.google.cloud.hadoop.fs.gcs.GoogleHadoopFileSystem"
+)
+spark.conf.set("viewsEnabled", "true")
+spark.conf.set("materializationDataset", "")
+```
+
+Below is the configuration that can be used to run SparkExpectations and write to BigQuery
+
+```python title="bigquery_write"
+import os
+from pyspark.sql import DataFrame
+from spark_expectations.core.expectations import (
+    SparkExpectations,
+    WrappedDataFrameWriter,
+)
+from spark_expectations.config.user_config import Constants as user_config
+
+os.environ[
+    "GOOGLE_APPLICATION_CREDENTIALS"
+] = "path_to_your_json_credential_file"  # This is needed for spark write to bigquery
+writer = (
+    WrappedDataFrameWriter().mode("overwrite")
+    .format("bigquery")
+    .option("createDisposition", "CREATE_IF_NEEDED")
+    .option("writeMethod", "direct")
+)
+
+se: SparkExpectations = SparkExpectations(
+    product_id="your_product",
+    rules_df=spark.read.format("bigquery").load(
+        ".."
+ ), + stats_table="..", + stats_table_writer=writer, + target_and_error_table_writer=writer, + debugger=False, + stats_streaming_options={user_config.se_enable_streaming: False} +) + + +# Commented fields are optional or required when notifications are enabled +user_conf = { + user_config.se_notifications_enable_email: False, + # user_config.se_notifications_email_smtp_host: "mailhost.com", + # user_config.se_notifications_email_smtp_port: 25, + # user_config.se_notifications_email_from: "", + # user_config.se_notifications_email_to_other_mail_id: "", + # user_config.se_notifications_email_subject: "spark expectations - data quality - notifications", + user_config.se_notifications_enable_slack: False, + # user_config.se_notifications_slack_webhook_url: "", + # user_config.se_notifications_on_start: True, + # user_config.se_notifications_on_completion: True, + # user_config.se_notifications_on_fail: True, + # user_config.se_notifications_on_error_drop_exceeds_threshold_breach: True, + # user_config.se_notifications_on_error_drop_threshold: 15, +} + + +@se.with_expectations( + target_table="..", + write_to_table=True, + user_conf=user_conf, + target_table_view="..", +) +def build_new() -> DataFrame: + _df_order: DataFrame = ( + spark.read.option("header", "true") + .option("inferSchema", "true") + .csv(os.path.join(os.path.dirname(__file__), "resources/order.csv")) + ) + _df_order.createOrReplaceTempView("order") + + return _df_order +``` diff --git a/docs/configurations/adoption_versions_comparsion.md b/docs/configurations/adoption_versions_comparsion.md index 8895eede..a26d24d6 100644 --- a/docs/configurations/adoption_versions_comparsion.md +++ b/docs/configurations/adoption_versions_comparsion.md @@ -4,16 +4,17 @@ Please find the difference in the changes with different version, latest three v -| stage | 0.6.0 | 0.7.0 | 0.8.0 | -| :------------------| :----------- | :----- | ------------------ | -| rules table schema changes | refer rule table creation [here](https://glowing-umbrella-j8jnolr.pages.github.io/0.7.0/getting-started/setup/) | added three additional column
1.`enable_for_source_dq_validation(boolean)`
2.`enable_for_target_dq_validation(boolean)`
3.`is_active(boolean)`

documentation found [here](https://glowing-umbrella-j8jnolr.pages.github.io/0.7.0/getting-started/setup/) | added additional two column
1.`enable_error_drop_alert(boolean)`
2.`error_drop_thresholdt(int)`

documentation found [here](https://glowing-umbrella-j8jnolr.pages.github.io/0.8.0/getting-started/setup/)| -| rule table creation required | yes | yes - creation not required if you're upgrading from old version but schema changes required | yes - creation not required if you're upgrading from old version but schema changes required | -| stats table schema changes | refer rule table creation [here](https://glowing-umbrella-j8jnolr.pages.github.io/0.7.0/getting-started/setup/) | added additional columns
1. `source_query_dq_results`
2. `final_query_dq_results`
3. `row_dq_res_summary`
4. `dq_run_time`
5. `dq_rules`

renamed columns
1. `runtime` to `meta_dq_run_time`
2. `run_date` to `meta_dq_run_date`
3. `run_id` to `meta_dq_run_id`

documentation found [here](https://glowing-umbrella-j8jnolr.pages.github.io/0.7.0/getting-started/setup/)| remains same |
-| stats table creation required | yes | yes - creation not required if you're upgrading from old version but schema changes required | automated |
-| notification config setting | define global notification param, register as env variable and place in the `__init__.py` file for multiple usage, [example](https://glowing-umbrella-j8jnolr.pages.github.io/0.7.0/examples/) | Define a global notification parameter in the `__init__.py` file to be used in multiple instances where the spark_conf parameter needs to be passed within the with_expectations function. [example](https://glowing-umbrella-j8jnolr.pages.github.io/0.7.0/examples/) | remains same |
-| secret store and kafka authentication details | not applicable | not applicable | Create a dictionary that contains your secret configuration values and register in `__init__.py` for multiple usage, [example](https://glowing-umbrella-j8jnolr.pages.github.io/0.8.0/examples/) |
-| spark expectations initialisation | create SparkExpectations class object using the `SparkExpectations` library and by passing the `product_id` | create spark expectations class object using `SpakrExpectations` by passing `product_id` and optional parameter `debugger` [example](https://glowing-umbrella-j8jnolr.pages.github.io/0.7.0/examples/) | create spark expectations class object using `SpakrExpectations` by passing `product_id` and additional optional parameter `debugger`, `stats_streaming_options` [example](https://glowing-umbrella-j8jnolr.pages.github.io/0.8.0/examples/) |
-| spark expectations decorator | The decorator allows for configuration by passing individual parameters to each decorator. However, registering a DataFrame view within a decorated function is not supported for implementations of query_dq [example](https://glowing-umbrella-j8jnolr.pages.github.io/0.7.0/examples/) | The decorator allows configurations to be logically grouped through a dictionary passed as a parameter to the decorator. Additionally, registering a DataFrame view within a decorated function is supported for implementations of query_dq. [example](https://glowing-umbrella-j8jnolr.pages.github.io/0.7.0/examples/) | remains same |
+| stage | 0.8.0 | 1.0.0 |
+|:----------------------------------------------|----------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------|
+| rules table schema changes | added two additional columns
1.`enable_error_drop_alert(boolean)`
2.`error_drop_threshold(int)`

documentation found [here](https://engineering.nike.com/spark-expectations/0.8.1/getting-started/setup/) | Remains same |
+| rule table creation required | yes - creation not required if you're upgrading from old version but schema changes required | yes - creation not required if you're upgrading from old version but schema changes required |
+| stats table schema changes | remains same | Remains Same |
+| stats table creation required | automated | Remains Same |
+| notification config setting | remains same | Remains Same |
+| secret store and kafka authentication details | Create a dictionary that contains your secret configuration values and register in `__init__.py` for multiple usage, [example](https://engineering.nike.com/spark-expectations/0.8.1/examples/) | Remains Same. You can disable streaming if needed, in the SparkExpectations class |
+| spark expectations initialisation | create a spark expectations class object using `SparkExpectations` by passing `product_id` and the additional optional parameters `debugger` and `stats_streaming_options` [example](https://engineering.nike.com/spark-expectations/0.8.1/examples/) | New arguments are added. Please follow this - [example](https://engineering.nike.com/spark-expectations/1.0.0/examples/) |
+| with_expectations decorator | remains same | New arguments are added. Please follow this - [example](https://engineering.nike.com/spark-expectations/1.0.0/examples/) |
+| WrappedDataFrameWriter | Doesn't exist | This is new; users need to provide the writer object to record the Spark conf that needs to be used while writing - [example](https://engineering.nike.com/spark-expectations/1.0.0/examples/) |
\ No newline at end of file
diff --git a/docs/configurations/rules.md b/docs/configurations/rules.md
index b59dcefd..a07355e3 100644
--- a/docs/configurations/rules.md
+++ b/docs/configurations/rules.md
@@ -4,34 +4,34 @@ Please find the different types of possible expectations
 
 #### Possible Row Data Quality Expectations
 
-| rule_description | rule_type | tag | rule_expectation |
-| :------------------| :-----------: | :-----: | ------------------: |
-| Expect that the values in the column should not be null/empty | null_validation | completeness | ```[col_name] is not null``` |
-| Ensure that the primary key values are unique and not duplicated | primary_key_validation| uniqueness | ```count(*) over(partition by [primary_key_or_combination_of_primary_key] order by 1)=1 ```|
-| Perform a thorough check to make sure that there are no duplicate values, if there are duplicates preserve one row into target | complete_duplicate_validation | uniqueness | ```row_number() over(partition by [all_the_column_in_dataset_b_ comma_separated] order by 1)=1```|
-| Verify that the date values are in the correct format | date_format_validation |validity |```to_date([date_col_name], '[mention_expected_date_format]') is not null``` |
-| Verify that the date values are in the correct format using regex | date_format_validation_with_regex | validity | ```[date_col_name] rlike '[regex_format_of_date]'``` |
-| Expect column value is date parseable | expect_column_values_to_be_date_parseable | validity | ```try_cast([date_col_name] as date)``` |
-| Verify values in a column to conform to a specified regular expression pattern | expect_column_values_to_match_regex| validity | ```[col_name] rlike '[regex_format]'``` |
-| Verify values in a column to not conform to a specified regular expression pattern | expect_column_values_to_not_match_regex| validity | ```[col_name] not rlike 
'[regex_format]'``` | -| Verify values in a column to match regex in list | expect_column_values_to_match_regex_list | validity | ```[col_name] not rlike '[regex format1]' or [col_name] not rlike '[regex_format2]' or [col_name] not rlike '[regex_format3]'``` | -| Expect the values in a column to belong to a specified set | expect_column_values_to_be_in_set | accuracy | ```[col_name] in ([values_in_comma_separated])```| -| Expect the values in a column not to belong to a specified set| expect_column_values_to_be_not_in_set |accuracy | ```[col_name] not in ([values_in_comma_separated])``` | -| Expect the values in a column to fall within a defined range | expect_column_values_to_be_in_range | accuracy | ```[col_name] between [min_threshold] and [max_threshold]``` | -| Expect the lengths of the values in a column to be within a specified range| expect_column_value_lengths_to_be_between | accuracy | ```length([col_name]) between [min_threshold] and [max_threshold]``` | -| Expect the lengths of the values in a column to be equal to a certain value | expect_column_value_lengths_to_be_equal | accuracy | ```length([col_name])=[threshold]``` | -| Expect values in the column to exceed a certain limit | expect_column_value_to_be_greater_than | accuracy| ```[col_name] > [threshold_value]``` | -| Expect values in the column not to exceed a certain limit| expect_column_value_to_be_lesser_than | accuracy | ```[col_name] < [threshold_value]``` | -| Expect values in the column to be equal to or exceed a certain limit | expect_column_value_greater_than_equal | accuracy | ```[col_name] >= [threshold_value]``` | -| Expect values in the column to be equal to or not exceed a certain limit | expect_column_value_lesser_than_equal | accuracy | ```[col_name] <= [threshold_value]``` | +| rule_description | category | tag | rule_expectation | +| :------------------|:------------------------------------------------:| :-----: | ------------------: | +| Expect that the values in the column should not be null/empty | null_validation | completeness | ```[col_name] is not null``` | +| Ensure that the primary key values are unique and not duplicated | primary_key_validation | uniqueness | ```count(*) over(partition by [primary_key_or_combination_of_primary_key] order by 1)=1 ```| +| Perform a thorough check to make sure that there are no duplicate values, if there are duplicates preserve one row into target | complete_duplicate_validation | uniqueness | ```row_number() over(partition by [all_the_column_in_dataset_b_ comma_separated] order by 1)=1```| +| Verify that the date values are in the correct format | date_format_validation |validity |```to_date([date_col_name], '[mention_expected_date_format]') is not null``` | +| Verify that the date values are in the correct format using regex | date_format_validation_with_regex | validity | ```[date_col_name] rlike '[regex_format_of_date]'``` | +| Expect column value is date parseable | expect_column_values_to_be_date_parseable | validity | ```try_cast([date_col_name] as date)``` | +| Verify values in a column to conform to a specified regular expression pattern | expect_column_values_to_match_regex | validity | ```[col_name] rlike '[regex_format]'``` | +| Verify values in a column to not conform to a specified regular expression pattern | expect_column_values_to_not_match_regex | validity | ```[col_name] not rlike '[regex_format]'``` | +| Verify values in a column to match regex in list | expect_column_values_to_match_regex_list | validity | ```[col_name] not rlike '[regex 
format1]' or [col_name] not rlike '[regex_format2]' or [col_name] not rlike '[regex_format3]'``` | +| Expect the values in a column to belong to a specified set | expect_column_values_to_be_in_set | accuracy | ```[col_name] in ([values_in_comma_separated])```| +| Expect the values in a column not to belong to a specified set| expect_column_values_to_be_not_in_set |accuracy | ```[col_name] not in ([values_in_comma_separated])``` | +| Expect the values in a column to fall within a defined range | expect_column_values_to_be_in_range | accuracy | ```[col_name] between [min_threshold] and [max_threshold]``` | +| Expect the lengths of the values in a column to be within a specified range| expect_column_value_lengths_to_be_between | accuracy | ```length([col_name]) between [min_threshold] and [max_threshold]``` | +| Expect the lengths of the values in a column to be equal to a certain value | expect_column_value_lengths_to_be_equal | accuracy | ```length([col_name])=[threshold]``` | +| Expect values in the column to exceed a certain limit | expect_column_value_to_be_greater_than | accuracy| ```[col_name] > [threshold_value]``` | +| Expect values in the column not to exceed a certain limit| expect_column_value_to_be_lesser_than | accuracy | ```[col_name] < [threshold_value]``` | +| Expect values in the column to be equal to or exceed a certain limit | expect_column_value_greater_than_equal | accuracy | ```[col_name] >= [threshold_value]``` | +| Expect values in the column to be equal to or not exceed a certain limit | expect_column_value_lesser_than_equal | accuracy | ```[col_name] <= [threshold_value]``` | | Expect values in column A to be greater than values in column B | expect_column_pair_values_A_to_be_greater_than_B | accuracy | ```[col_A] > [col_B]``` | -| Expect values in column A to be lesser than values in column B | expect_column_pair_values_A_to_be_lesser_than_B | accuracy | ```[col_A] < [col_B]``` | -| Expect values in column A to be greater than or equals to values in column B | expect_column_A_to_be_greater_than_B | accuracy | ```[col_A] >= [col_B]``` | -| Expect values in column A to be lesser than or equals to values in column B | expect_column_A_to_be_lesser_than_or_equals_B |accuracy | ```[col_A] <= [col_B]``` | -| Expect the sum of values across multiple columns to be equal to a certain value | expect_multicolumn_sum_to_equal | accuracy | ```[col_1] + [col_2] + [col_3] = [threshold_value]``` | -| Expect sum of values in each category equals certain value | expect_sum_of_value_in_subset_equal | accuracy | ```sum([col_name]) over(partition by [category_col] order by 1)``` | -| Expect count of values in each category equals certain value | expect_count_of_value_in_subset_equal | accuracy | ```count(*) over(partition by [category_col] order by 1)``` | -| Expect distinct value in each category exceeds certain range | expect_distinct_value_in_subset_exceeds | accuracy | ```count(distinct [col_name]) over(partition by [category_col] order by 1)``` | +| Expect values in column A to be lesser than values in column B | expect_column_pair_values_A_to_be_lesser_than_B | accuracy | ```[col_A] < [col_B]``` | +| Expect values in column A to be greater than or equals to values in column B | expect_column_A_to_be_greater_than_B | accuracy | ```[col_A] >= [col_B]``` | +| Expect values in column A to be lesser than or equals to values in column B | expect_column_A_to_be_lesser_than_or_equals_B |accuracy | ```[col_A] <= [col_B]``` | +| Expect the sum of values across multiple columns to be equal 
to a certain value | expect_multicolumn_sum_to_equal | accuracy | ```[col_1] + [col_2] + [col_3] = [threshold_value]``` | +| Expect sum of values in each category equals certain value | expect_sum_of_value_in_subset_equal | accuracy | ```sum([col_name]) over(partition by [category_col] order by 1)``` | +| Expect count of values in each category equals certain value | expect_count_of_value_in_subset_equal | accuracy | ```count(*) over(partition by [category_col] order by 1)``` | +| Expect distinct value in each category exceeds certain range | expect_distinct_value_in_subset_exceeds | accuracy | ```count(distinct [col_name]) over(partition by [category_col] order by 1)``` | diff --git a/docs/delta.md b/docs/delta.md new file mode 100644 index 00000000..0ff04883 --- /dev/null +++ b/docs/delta.md @@ -0,0 +1,82 @@ +### Example - Write to Delta + +Setup SparkSession for Delta Lake to test in your local environment. Configure accordingly for higher environments. +Refer to Examples in [base_setup.py](../spark_expectations/examples/base_setup.py) and +[delta.py](../spark_expectations/examples/sample_dq_delta.py) + +```python title="spark_session" +from pyspark.sql import SparkSession + +builder = ( + SparkSession.builder.config( + "spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension" + ) + .config("spark.jars.packages", "io.delta:delta-core_2.12:2.4.0") + .config( + "spark.sql.catalog.spark_catalog", + "org.apache.spark.sql.delta.catalog.DeltaCatalog", + ) + .config("spark.sql.warehouse.dir", "/tmp/hive/warehouse") + .config("spark.driver.extraJavaOptions", "-Dderby.system.home=/tmp/derby") + .config("spark.jars.ivy", "/tmp/ivy2") + ) +spark = builder.getOrCreate() +``` + +Below is the configuration that can be used to run SparkExpectations and write to DeltaLake + +```python title="delta_write" +import os +from pyspark.sql import DataFrame +from spark_expectations.core.expectations import ( + SparkExpectations, + WrappedDataFrameWriter, +) +from spark_expectations.config.user_config import Constants as user_config + +writer = WrappedDataFrameWriter().mode("append").format("delta") + +se: SparkExpectations = SparkExpectations( + product_id="your_product", + rules_df=spark.table("dq_spark_local.dq_rules"), + stats_table="dq_spark_local.dq_stats", + stats_table_writer=writer, + target_and_error_table_writer=writer, + debugger=False, + stats_streaming_options={user_config.se_enable_streaming: False} +) + +# Commented fields are optional or required when notifications are enabled +user_conf = { + user_config.se_notifications_enable_email: False, + # user_config.se_notifications_email_smtp_host: "mailhost.com", + # user_config.se_notifications_email_smtp_port: 25, + # user_config.se_notifications_email_from: "", + # user_config.se_notifications_email_to_other_mail_id: "", + # user_config.se_notifications_email_subject: "spark expectations - data quality - notifications", + user_config.se_notifications_enable_slack: False, + # user_config.se_notifications_slack_webhook_url: "", + # user_config.se_notifications_on_start: True, + # user_config.se_notifications_on_completion: True, + # user_config.se_notifications_on_fail: True, + # user_config.se_notifications_on_error_drop_exceeds_threshold_breach: True, + # user_config.se_notifications_on_error_drop_threshold: 15, +} + + +@se.with_expectations( + target_table="dq_spark_local.customer_order", + write_to_table=True, + user_conf=user_conf, + target_table_view="order", +) +def build_new() -> DataFrame: + _df_order: DataFrame = ( + 
spark.read.option("header", "true") + .option("inferSchema", "true") + .csv(os.path.join(os.path.dirname(__file__), "resources/order.csv")) + ) + _df_order.createOrReplaceTempView("order") + + return _df_order +``` diff --git a/docs/examples.md b/docs/examples.md index beb00b6d..41aae9c1 100644 --- a/docs/examples.md +++ b/docs/examples.md @@ -6,7 +6,7 @@ In order to establish the global configuration parameter for DQ Spark Expectatio ```python from spark_expectations.config.user_config import Constants as user_config -se_global_spark_Conf = { +se_user_conf = { user_config.se_notifications_enable_email: False, # (1)! user_config.se_notifications_email_smtp_host: "mailhost.com", # (2)! user_config.se_notifications_email_smtp_port: 25, # (3)! @@ -49,16 +49,16 @@ from typing import Dict, Union from spark_expectations.config.user_config import Constants as user_config stats_streaming_config_dict: Dict[str, Union[bool, str]] = { - user_config.se_enable_streaming: True, # (1)! - user_config.secret_type: "databricks", # (2)! - user_config.dbx_workspace_url : "https://workspace.cloud.databricks.com", # (3)! - user_config.dbx_secret_scope: "sole_common_prod", # (4)! - user_config.dbx_kafka_server_url: "se_streaming_server_url_secret_key", # (5)! - user_config.dbx_secret_token_url: "se_streaming_auth_secret_token_url_key", # (6)! - user_config.dbx_secret_app_name: "se_streaming_auth_secret_appid_key", # (7)! - user_config.dbx_secret_token: "se_streaming_auth_secret_token_key", # (8)! - user_config.dbx_topic_name: "se_streaming_topic_name", # (9)! - } + user_config.se_enable_streaming: True, # (1)! + user_config.secret_type: "databricks", # (2)! + user_config.dbx_workspace_url : "https://workspace.cloud.databricks.com", # (3)! + user_config.dbx_secret_scope: "sole_common_prod", # (4)! + user_config.dbx_kafka_server_url: "se_streaming_server_url_secret_key", # (5)! + user_config.dbx_secret_token_url: "se_streaming_auth_secret_token_url_key", # (6)! + user_config.dbx_secret_app_name: "se_streaming_auth_secret_appid_key", # (7)! + user_config.dbx_secret_token: "se_streaming_auth_secret_token_key", # (8)! + user_config.dbx_topic_name: "se_streaming_topic_name", # (9)! +} ``` 1. The `user_config.se_enable_streaming` parameter is used to control the enabling or disabling of Spark Expectations (SE) streaming functionality. When enabled, SE streaming stores the statistics of every batch run into Kafka. @@ -78,16 +78,16 @@ from typing import Dict, Union from spark_expectations.config.user_config import Constants as user_config stats_streaming_config_dict: Dict[str, Union[bool, str]] = { - user_config.se_enable_streaming: True, # (1)! - user_config.secret_type: "databricks", # (2)! - user_config.cbs_url : "https://.cerberus.com", # (3)! - user_config.cbs_sdb_path: "cerberus_sdb_path", # (4)! - user_config.cbs_kafka_server_url: "se_streaming_server_url_secret_sdb_path", # (5)! - user_config.cbs_secret_token_url: "se_streaming_auth_secret_token_url_sdb_apth", # (6)! - user_config.cbs_secret_app_name: "se_streaming_auth_secret_appid_sdb_path", # (7)! - user_config.cbs_secret_token: "se_streaming_auth_secret_token_sdb_path", # (8)! - user_config.cbs_topic_name: "se_streaming_topic_name_sdb_path", # (9)! - } + user_config.se_enable_streaming: True, # (1)! + user_config.secret_type: "databricks", # (2)! + user_config.cbs_url : "https://.cerberus.com", # (3)! + user_config.cbs_sdb_path: "cerberus_sdb_path", # (4)! + user_config.cbs_kafka_server_url: "se_streaming_server_url_secret_sdb_path", # (5)! 
+ user_config.cbs_secret_token_url: "se_streaming_auth_secret_token_url_sdb_apth", # (6)! + user_config.cbs_secret_app_name: "se_streaming_auth_secret_appid_sdb_path", # (7)! + user_config.cbs_secret_token: "se_streaming_auth_secret_token_sdb_path", # (8)! + user_config.cbs_topic_name: "se_streaming_topic_name_sdb_path", # (9)! +} ``` 1. The `user_config.se_enable_streaming` parameter is used to control the enabling or disabling of Spark Expectations (SE) streaming functionality. When enabled, SE streaming stores the statistics of every batch run into Kafka. @@ -100,363 +100,66 @@ stats_streaming_config_dict: Dict[str, Union[bool, str]] = { 8. The `user_config.cbs_secret_token` captures path where kafka authentication app name secret token stored in the cerberus sdb 9. The `user_config.cbs_topic_name` captures path where kafka topic name stored in the cerberus sdb -```python -from spark_expectations.core.expectations import SparkExpectations - -# product_id should match with the "product_id" in the rules table -se: SparkExpectations = SparkExpectations(product_id="your-products-id", stats_streaming_options=stats_streaming_config_dict) # (1)! -``` - - -1. Instantiate `SparkExpectations` class which has all the required functions for running data quality rules - - -#### Example 1 - -```python -from spark_expectations.config.user_config import * # (7)! - - -@se.with_expectations( # (6)! - se.reader.get_rules_from_table( # (5)! - product_rules_table="pilot_nonpub.dq.dq_rules", # (1)! - table_name="pilot_nonpub.dq_employee.employee", # (2)! - dq_stats_table_name="pilot_nonpub.dq.dq_stats" # (3)! - ), - write_to_table=True, # (4)! - write_to_temp_table=True, # (8)! - row_dq=True, # (9)! - agg_dq={ # (10)! - user_config.se_agg_dq: True, # (11)! - user_config.se_source_agg_dq: True, # (12)! - user_config.se_final_agg_dq: True, # (13)! - }, - query_dq={ # (14)! - user_config.se_query_dq: True, # (15)! - user_config.se_source_query_dq: True, # (16)! - user_config.se_final_query_dq: True, # (17)! - user_config.se_target_table_view: "order", # (18)! - }, - spark_conf=se_global_spark_Conf, # (19)! - -) -def build_new() -> DataFrame: - _df_order: DataFrame = ( - spark.read.option("header", "true") - .option("inferSchema", "true") - .csv(os.path.join(os.path.dirname(__file__), "resources/order.csv")) - ) - _df_order.createOrReplaceTempView("order") # (20)! - - _df_product: DataFrame = ( - spark.read.option("header", "true") - .option("inferSchema", "true") - .csv(os.path.join(os.path.dirname(__file__), "resources/product.csv")) - ) - _df_product.createOrReplaceTempView("product") # (20)! - - _df_customer: DataFrame = ( - spark.read.option("header", "true") - .option("inferSchema", "true") - .csv(os.path.join(os.path.dirname(__file__), "resources/customer.csv")) - ) - - _df_customer.createOrReplaceTempView("customer") # (20)! - - return _df_order # (21)! -``` - - -1. Provide the full table name of the table which contains the rules -2. Provide the table name using which the `_error` table will be created, which contains all the failed records. - Note if you are also wanting to write the data using `write_df`, then the table_name provided to both the functions - should be same -3. Provide the full table name where the stats will be written into -4. Use this argument to write the final data into the table. By default, it is False. - This is optional, if you just want to run the data quality checks. - A good example will be a staging table or temporary view. -5. 
This functions reads the rules from the table and return them as a dict, which is an input to the `with_expectations` function -6. This is the decorator that helps us run the data quality rules. After running the rules the results will be written into `_stats` table and `error` table -7. import necessary configurable variables from `user_config` package for the specific functionality to configure in spark-expectations -8. Use this argument to write the input dataframe into the temp table, so that it breaks the spark plan and might speed - up the job in cases of complex dataframe lineage -9. The argument row_dq is optional and enables the conducting of row-based data quality checks. By default, this - argument is set to True, however, if desired, these checks can be skipped by setting the argument to False. -10. The `agg_dq` argument is a dictionary that is used to gather different settings and options for the purpose of configuring the `agg_dq` -11. The argument `se_agg_dq` is utilized to activate the aggregate data quality check, and its default setting is True. -12. The `se_source_agg_dq` argument is optional and enables the conducting of aggregate-based data quality checks on the - source data. By default, this argument is set to True, and this option depends on the `agg_dq` value. - If desired, these checks can be skipped by setting the source_agg_dq argument to False. -13. This optional argument `se_final_agg_dq` allows to perform agg-based data quality checks on final data, with the - default setting being `True`, which depended on `row_agg` and `agg_dq`. skip these checks by setting argument to `False` -14. The `query_dq` argument is a dictionary that is used to gather different settings and options for the purpose of configuring the `query_dq` -15. The argument `se_query_dq` is utilized to activate the aggregate data quality check, and its default setting is True. -16. The `se_source_query_dq` argument is optional and enables the conducting of query-based data quality checks on the - source data. By default, this argument is set to True, and this option depends on the `agg_dq` value. - If desired, these checks can be skipped by setting the source_agg_dq argument to False. -17. This optional argument `se_final_query_dq` allows to perform query_based data quality checks on final data, with the - default setting being `True`, which depended on `row_agg` and `agg_dq`. skip these checks by setting argument to `False` -18. The parameter `se_target_table_view` can be provided with the name of a view that represents the target validated dataset for implementation of `query_dq` on the clean dataset from `row_dq` -19. The `spark_conf` parameter is utilized to gather all the configurations that are associated with notifications -20. View registration can be utilized when implementing `query_dq` expectations. -21. Returning a dataframe is mandatory for the `spark_expectations` to work, if we do not return a dataframe - then an exceptionm will be raised - - -#### Example 2 - -```python -@se.with_expectations( # (1)! - se.reader.get_rules_from_table( - product_rules_table="pilot_nonpub.dq.dq_rules", - target_table_name="pilot_nonpub.customer_order", - dq_stats_table_name="pilot_nonpub.dq.dq_stats") - ), - row_dq=True # (2)! -) -def build_new() -> DataFrame: - _df: DataFrame = ( - spark.read.option("header", "true") - .option("inferSchema", "true") - .csv(os.path.join(os.path.dirname(__file__), "resources/employee.csv")) - ) - return df -``` - -1. 
Conduct only row-based data quality checks while skipping the aggregate data quality checks -2. Disabled the aggregate data quality checks - -#### Example 3 - -```python -@se.with_expectations( # (1)! - se.reader.get_rules_from_table( - product_rules_table="pilot_nonpub.dq.dq_rules", - target_table_name="pilot_nonpub.customer_order", - dq_stats_table_name="pilot_nonpub.dq.dq_stats" - ), - row_dq=False, # (2)! - agg_dq={ - user_config.se_agg_dq: True, - user_config.se_source_agg_dq: True, - user_config.se_final_agg_dq: False, - } - -) -def build_new() -> DataFrame: - _df: DataFrame = ( - spark.read.option("header", "true") - .option("inferSchema", "true") - .csv(os.path.join(os.path.dirname(__file__), "resources/employee.csv")) - ) - return df -``` - -1. Perform only aggregate-based data quality checks while avoiding both row-based data quality checks and aggregate data - quality checks on the validated dataset, since row validation has not taken place -2. Disabled the row data quality checks - -#### Example 4 - -```python -@se.with_expectations( # (1)! - se.reader.get_rules_from_table( - product_rules_table="pilot_nonpub.dq.dq_rules", - target_table_name="pilot_nonpub.customer_order", - dq_stats_table_name="pilot_nonpub.dq.dq_stats" - ), - row_dq=True, - query_dq={ # (2)! - user_config.se_query_dq: True, - user_config.se_source_query_dq: True, - user_config.se_final_query_dq: True, - user_config.se_target_table_view: "order", - }, - -) -def build_new() -> DataFrame: - _df_order: DataFrame = ( - spark.read.option("header", "true") - .option("inferSchema", "true") - .csv(os.path.join(os.path.dirname(__file__), "resources/order.csv")) - ) - _df_order.createOrReplaceTempView("order") - _df_product: DataFrame = ( - spark.read.option("header", "true") - .option("inferSchema", "true") - .csv(os.path.join(os.path.dirname(__file__), "resources/product.csv")) - ) - _df_product.createOrReplaceTempView("product") - - _df_customer: DataFrame = ( - spark.read.option("header", "true") - .option("inferSchema", "true") - .csv(os.path.join(os.path.dirname(__file__), "resources/customer.csv")) - ) - - _df_customer.createOrReplaceTempView("customer") - - return _df_order -``` - -1. Conduct row-based and query-based data quality checks only on the source and target dataset, while skipping the aggregate - data quality checks on the validated dataset -2. Enabled the query data quality checks - -#### Example 5 +You can disable the streaming functionality by setting the `user_config.se_enable_streaming` parameter to `False` ```python -@se.with_expectations( # (1)! - se.reader.get_rules_from_table( - product_rules_table="pilot_nonpub.dq.dq_rules", - target_table_name="pilot_nonpub.customer_order", - dq_stats_table_name="pilot_nonpub.dq.dq_stats" - ), - row_dq=True, - agg_dq={ # (10)! - user_config.user_configse_agg_dq: True, - user_config.se_source_agg_dq: True, - user_config.se_final_agg_dq: False, # (2)! - }, +from typing import Dict, Union +from spark_expectations.config.user_config import Constants as user_config -) -def build_new() -> DataFrame: - _df_order: DataFrame = ( - spark.read.option("header", "true") - .option("inferSchema", "true") - .csv(os.path.join(os.path.dirname(__file__), "resources/order.csv")) - ) - return _df_order +stats_streaming_config_dict: Dict[str, Union[bool, str]] = { + user_config.se_enable_streaming: False, # (1)! +} ``` -1. 
Conduct row-based and aggregate-based data quality checks only on the source dataset, while skipping the aggregate - data quality checks on the validated dataset -2. Disabled the final aggregate data quality quality checks - - -#### Example 6 +1. The `user_config.se_enable_streaming` parameter is used to control the enabling or disabling of Spark Expectations (SE) streaming functionality. When enabled, SE streaming stores the statistics of every batch run into Kafka. ```python -import os +from spark_expectations.core.expectations import SparkExpectations -@se.with_expectations( - se.reader.get_rules_from_table( - product_rules_table="pilot_nonpub.dq.dq_rules", - target_table_name="pilot_nonpub.customer_order", - dq_stats_table_name="pilot_nonpub.dq.dq_stats" - ), - spark_conf=se_global_spark_Conf, # (2)! - -) -def build_new() -> DataFrame: - _df_order: DataFrame = ( - spark.read.option("header", "true") - .option("inferSchema", "true") - .csv(os.path.join(os.path.dirname(__file__), "resources/order.csv")) - ) - return _df_order +# product_id should match with the "product_id" in the rules table +se: SparkExpectations = SparkExpectations( + product_id="your-products-id", + stats_streaming_options=stats_streaming_config_dict) # (1)! ``` -1. There are four types of notifications: notification_on_start, notification_on_completion, notification_on_fail and notification_on_error_threshold_breach. - Enable notifications for all four stages by setting the values to `True` -2. To provide the absolute file path for a configuration variable that holds information regarding notifications, use the - decalared global variable, `se_global_spark_Conf` - -#### Example 7 - -```python -@se.with_expectations( # (1)! - se.reader.get_rules_from_table( - product_rules_table="pilot_nonpub.dq.dq_rules", - target_table_name="pilot_nonpub.customer_order", - dq_stats_table_name="pilot_nonpub.dq.dq_stats" - ), - row_dq=False, - agg_dq={ - user_config.se_agg_dq: False, - user_config.se_source_agg_dq: False, - user_config.se_final_agg_dq: True, - }, - query_dq={ - user_config.se_query_dq: False, - user_config.se_source_query_dq: True, - user_config.se_final_query_dq: True, - user_config.se_target_table_view: "order", - }, -) -def build_new() -> DataFrame: - _df_order: DataFrame = ( - spark.read.option("header", "true") - .option("inferSchema", "true") - .csv(os.path.join(os.path.dirname(__file__), "resources/order.csv")) - ) - return _df_order -``` - -1. Below combination of `row_dq, agg_dq, source_agg_dq, final_agg_dq, query_dq, source_query_dq and final_query_dq` skips the data quality checks because - source_agg_dq depends on agg_dq and final_agg_dq depends on row_dq and agg_dq +1. Instantiate `SparkExpectations` class which has all the required functions for running data quality rules -#### Example 8 +#### Example 1 ```python -@se.with_expectations( # (1)! - se.reader.get_rules_from_table( - product_rules_table="pilot_nonpub.dq.dq_rules", - target_table_name="pilot_nonpub.customer_order", - dq_stats_table_name="pilot_nonpub.dq.dq_stats", - actions_if_failed=["drop", "ignore"] # (1)! - ) +from spark_expectations.core.expectations import SparkExpectations, WrappedDataFrameWriter + +writer = WrappedDataFrameWriter().mode("append").format("delta") # (1)! +se = SparkExpectations( # (10)! + product_id="your_product", # (11)! + rules_df=spark.table("dq_spark_local.dq_rules"), # (12)! + stats_table="dq_spark_local.dq_stats", # (13)! + stats_table_writer=writer, # (14)! + target_and_error_table_writer=writer, # (15)! 
+ debugger=False, # (16)! + # stats_streaming_options={user_config.se_enable_streaming: False}, # (17)! ) -def build_new() -> DataFrame: - _df_order: DataFrame = ( - spark.read.option("header", "true") - .option("inferSchema", "true") - .csv(os.path.join(os.path.dirname(__file__), "resources/order.csv")) - ) - return _df_order -``` - -1. By default `action_if_failed` contains ["fail", "drop", "ignore"], but if we want to run only rules which has a - particular action then we can pass them as list shown in the example - - -#### Example 9 - -```python -@se.with_expectations( # (1)! - se.reader.get_rules_from_table( - product_rules_table="pilot_nonpub.dq.dq_rules", - target_table_name="pilot_nonpub.customer_order", - dq_stats_table_name="pilot_nonpub.dq.dq_stats", - actions_if_failed=["drop", "ignore"] # (1)! - ), - row_dq=True, # (2)! - agg_dq={ - user_config.se_agg_dq: True, - user_config.se_source_agg_dq: True, - user_config.se_final_agg_dq: True, - }, - query_dq={ - user_config.se_query_dq: True, - user_config.se_source_query_dq: True, - user_config.se_final_query_dq: True, - user_config.se_target_table_view: "order", - } - +@se.with_expectations( # (2)! + write_to_table=True, # (3)! + write_to_temp_table=True, # (4)! + user_conf=se_user_conf, # (5)! + target_table_view="order", # (6)! + target_and_error_table_writer=writer, # (7)! ) def build_new() -> DataFrame: - _df_order: DataFrame = ( + _df_order: DataFrame = ( spark.read.option("header", "true") .option("inferSchema", "true") .csv(os.path.join(os.path.dirname(__file__), "resources/order.csv")) ) - _df_order.createOrReplaceTempView("order") + _df_order.createOrReplaceTempView("order") _df_product: DataFrame = ( spark.read.option("header", "true") .option("inferSchema", "true") - .csv(os.path.join(os.path.dirname(__file__), "resources/product.csv")) + .csv(os.path.join(os.path.dirname(__file__), "resources/product.csv")) ) _df_product.createOrReplaceTempView("product") @@ -466,41 +169,26 @@ def build_new() -> DataFrame: .csv(os.path.join(os.path.dirname(__file__), "resources/customer.csv")) ) - _df_customer.createOrReplaceTempView("customer") - - return _df_order - -``` - -1. The default options for the action_if_failed field are ["fail", "drop", or "ignore"], but you can specify which of - these actions to run by providing a list of the desired actions in the example when selecting which data quality rules - set to apply -2. Data quality rules will only be applied if they have ["drop" or "ignore"] specified in the action_if_failed field - + _df_customer.createOrReplaceTempView("customer") # (8)! -#### Example 10 - -```python -@se.with_expectations( - se.reader.get_rules_from_table( - product_rules_table="pilot_nonpub.dq.dq_rules", - target_table_name="pilot_nonpub.customer_order", - dq_stats_table_name="pilot_nonpub.dq.dq_stats" - ), - spark_conf={"spark.files.maxPartitionBytes": "134217728"}, # (1)! - options={"mode": "overwrite", "partitionBy": "order_month", - "overwriteSchema": "true"}, # (2)! - options_error_table={"partition_by": "id"} # (3)! -) -def build_new() -> DataFrame: - _df_order: DataFrame = ( - spark.read.option("header", "true") - .option("inferSchema", "true") - .csv(os.path.join(os.path.dirname(__file__), "resources/order.csv")) - ) - return _df_order + return _df_order # (9)! ``` -1. Provide the optional `spark_conf` if needed, this is used while writing the data into the `final` and `error` table along with notification related configurations -2. 
Provide the optional `options` if needed, this is used while writing the data into the `final` table
-3. Provide the optional `options_error_table` if needed, this is used while writing the data into the `error` table
\ No newline at end of file
+1. The `WrappedDataFrameWriter` class wraps the `DataFrameWriter` API and records the mode, format and options that Spark Expectations uses when writing tables
+2. The `@se.with_expectations` decorator is used to run the data quality rules
+3. The `write_to_table` parameter is used to write the final data into the table. By default, it is False. This is optional, if you just want to run the data quality checks. A good example will be a staging table or temporary view.
+4. The `write_to_temp_table` parameter is used to write the input dataframe into a temp table, so that it breaks the spark plan and might speed up the job in cases of complex dataframe lineage
+5. The `user_conf` parameter is utilized to gather all the configurations that are associated with notifications. There are four types of notifications: notification_on_start, notification_on_completion, notification_on_fail and notification_on_error_threshold_breach.
   Enable notifications for all four stages by setting the values to `True`. By default, all four stages are set to `False`
+6. The `target_table_view` parameter is used to provide the name of a view that represents the target validated dataset for implementation of `query_dq` on the clean dataset from `row_dq`
+7. The `target_and_error_table_writer` parameter optionally overrides, at the decorator level, the writer that is used for the target and error tables
+8. View registration can be utilized when implementing `query_dq` expectations.
+9. Returning a dataframe is mandatory for `spark_expectations` to work; if we do not return a dataframe, an exception will be raised
+10. Instantiate the `SparkExpectations` class, which has all the required functions for running data quality rules
+11. The `product_id` parameter is used to specify the product ID of the data quality rules. This has to be a unique value
+12. The `rules_df` parameter is used to specify the dataframe that contains the data quality rules
+13. The `stats_table` parameter is used to specify the table name where the statistics will be written into
+14. The `stats_table_writer` takes in the configuration that needs to be used to write the stats table using pyspark
+15. The `target_and_error_table_writer` takes in the configuration that needs to be used to write the target and error table using pyspark
+16. The `debugger` parameter is used to enable the debugger mode
+17. The `stats_streaming_options` parameter is used to specify the configurations for streaming statistics into Kafka. To not stream the stats into Kafka, uncomment this line.
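+
+For reference, a minimal sketch of how the writer used above is built (the BigQuery option key below only illustrates chaining and is not a required setting): `WrappedDataFrameWriter` records the mode, format and options that Spark Expectations replays when it writes the target, error and stats tables:
+
+```python
+from spark_expectations.core.expectations import WrappedDataFrameWriter
+
+# Nothing is written at this point; the calls only record the
+# configuration that Spark Expectations applies when it writes.
+delta_writer = WrappedDataFrameWriter().mode("append").format("delta")
+
+bigquery_writer = (
+    WrappedDataFrameWriter()
+    .mode("overwrite")
+    .format("bigquery")
+    .option("writeMethod", "direct")  # options chain like a regular DataFrameWriter
+)
+```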
diff --git a/docs/iceberg.md b/docs/iceberg.md
new file mode 100644
index 00000000..d9123aa1
--- /dev/null
+++ b/docs/iceberg.md
@@ -0,0 +1,88 @@
+### Example - Write to Iceberg
+
+Set up a SparkSession for Iceberg to test in your local environment. Configure accordingly for higher environments.
+Refer to the examples in [base_setup.py](../spark_expectations/examples/base_setup.py) and
+[sample_dq_iceberg.py](../spark_expectations/examples/sample_dq_iceberg.py)
+
+```python title="spark_session"
+from pyspark.sql import SparkSession
+
+builder = (
+    SparkSession.builder.config(
+        "spark.jars.packages",
+        "org.apache.iceberg:iceberg-spark-runtime-3.4_2.12:1.3.1",
+    )
+    .config(
+        "spark.sql.extensions",
+        "org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions",
+    )
+    .config(
+        "spark.sql.catalog.spark_catalog",
+        "org.apache.iceberg.spark.SparkSessionCatalog",
+    )
+    .config("spark.sql.catalog.spark_catalog.type", "hadoop")
+    .config("spark.sql.catalog.spark_catalog.warehouse", "/tmp/hive/warehouse")
+    .config("spark.sql.catalog.local", "org.apache.iceberg.spark.SparkCatalog")
+    .config("spark.sql.catalog.local.type", "hadoop")
+    .config("spark.sql.catalog.local.warehouse", "/tmp/hive/warehouse")
+)
+spark = builder.getOrCreate()
+```
+
+Below is the configuration that can be used to run SparkExpectations and write to Iceberg
+
+```python title="iceberg_write"
+import os
+from pyspark.sql import DataFrame
+from spark_expectations.core.expectations import (
+    SparkExpectations,
+    WrappedDataFrameWriter,
+)
+from spark_expectations.config.user_config import Constants as user_config
+
+writer = WrappedDataFrameWriter().mode("append").format("iceberg")
+
+se: SparkExpectations = SparkExpectations(
+    product_id="your_product",
+    rules_df=spark.sql("select * from dq_spark_local.dq_rules"),
+    stats_table="dq_spark_local.dq_stats",
+    stats_table_writer=writer,
+    target_and_error_table_writer=writer,
+    debugger=False,
+    stats_streaming_options={user_config.se_enable_streaming: False},
+)
+
+# Commented fields are optional or required when notifications are enabled
+user_conf = {
+    user_config.se_notifications_enable_email: False,
+    # user_config.se_notifications_email_smtp_host: "mailhost.com",
+    # user_config.se_notifications_email_smtp_port: 25,
+    # user_config.se_notifications_email_from: "",
+    # user_config.se_notifications_email_to_other_mail_id: "",
+    # user_config.se_notifications_email_subject: "spark expectations - data quality - notifications",
+    user_config.se_notifications_enable_slack: False,
+    # user_config.se_notifications_slack_webhook_url: "",
+    # user_config.se_notifications_on_start: True,
+    # user_config.se_notifications_on_completion: True,
+    # user_config.se_notifications_on_fail: True,
+    # user_config.se_notifications_on_error_drop_exceeds_threshold_breach: True,
+    # user_config.se_notifications_on_error_drop_threshold: 15,
+}
+
+
+@se.with_expectations(
+    target_table="dq_spark_local.customer_order",
+    write_to_table=True,
+    user_conf=user_conf,
+    target_table_view="order",
+)
+def build_new() -> DataFrame:
+    _df_order: DataFrame = (
+        spark.read.option("header", "true")
+        .option("inferSchema", "true")
+        .csv(os.path.join(os.path.dirname(__file__), "resources/order.csv"))
+    )
+    _df_order.createOrReplaceTempView("order")
+
+    return _df_order
+```
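+
+After a run, the collected statistics can be inspected from the stats table configured above (a quick check, assuming the table name used in this example):
+
+```python
+# Inspect the data quality statistics produced by the decorated run
+spark.table("dq_spark_local.dq_stats").show(truncate=False)
+```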

+<!-- image: spark-expectations flow diagram -->
+
+<!-- image: spark-expectations feature diagram -->
+ + ## Concept Most of the data quality tools do the data quality checks or data validation on a table at rest and provide metrics in different forms. `While the existing tools are good to do profiling and provide metrics, below are the problems that we @@ -54,6 +65,3 @@ final table or not? Below are the heirarchy of checks that happens? recorded in the `_stats` table and the job will be considered a failure. However, if none of the failed rules has an _action_if_failed_as_fail_, then summary of the aggregated rules' metadata will still be collected in the `_stats` table for failed aggregated and query dq rules. - - -Please find the spark-expectations flow and feature diagrams [here](./se_diagrams/spark_expectations_flow_and_feature.pptx) \ No newline at end of file diff --git a/mike-mkdocs_vbm2uyw.yml b/mike-mkdocs_vbm2uyw.yml deleted file mode 100644 index afbf3655..00000000 --- a/mike-mkdocs_vbm2uyw.yml +++ /dev/null @@ -1,146 +0,0 @@ -site_name: Spark-Expectations -site_author: Ashok Singamaneni -copyright: Nike -site_description: Spark-Expectations is a framework for running data quality rules - inflight of spark job. -site_url: https://github.com/Nike-Inc/spark-expectations/ -repo_name: nike/spark-expectations -repo_url: https://github.com/Nike-Inc/spark-expectations/ -theme: - name: material - palette: - - scheme: default - primary: indigo - accent: indigo - toggle: - icon: material/brightness-7 - name: Switch to dark mode - - scheme: slate - primary: indigo - accent: indigo - toggle: - icon: material/brightness-4 - name: Switch to light mode - features: - - content.code.annotate - - content.tooltips - - navigation.expand - - navigation.indexes - - navigation.instant - - navigation.tabs - - navigation.tabs.sticky - - navigation.top - - navigation.tracking - - search.highlight - - search.share - - search.suggest - - toc.follow - font: - text: Roboto - code: Roboto Mono - language: en -nav: -- Home: index.md -- Setup: getting-started/setup.md -- Python API: - - Core: - Init: api/core_init.md - Context: api/context.md - Exceptions: api/exceptions.md - Expectations: api/expectations.md - - Sinks: - Init: api/sinks_init.md - Base_Writer_Plugin: api/base_sink_plugin.md - Delta_Sink_Plugin: api/delta_sink_plugin.md - Kafka_Sink_plugin: api/kafka_sink_plugin.md - Utils: - - Writer: api/writer.md - - Sink_Decorater: api/sinks_decorater - - Notification: - Init: api/notifications_init.md - Base_Notification_plugin: api/base_notification_plugin.md - Email_Notification_plugin: api/email_plugin.md - Slack_Notification_plugin: api/slack_plugin.md - Push: - - Notification_Decorater: api/notifications_decorater - - Utils: - Actions: api/actions.md - Reader: api/reader.md - Regulate_flow: api/regulate_flow.md - Udf: api/udf.md - - Examples: - Base_Setup: api/base_setup.md - Sample_DQ: api/sample_dq.md -- Examples: - Rules: configurations/rules.md - Configure_Rules: configurations/configure_rules.md - Initialization_Examples: examples.md -plugins: -- mike -- search: - lang: en -- mkdocstrings: - handlers: - python: - paths: - - spark_expectations - options: - show_source: true - show_root_heading: false - heading_level: 1 - merge_init_into_class: true - show_if_no_docstring: true - show_root_full_path: true - show_root_members_full_path: true - show_root_toc_entry: false - show_category_heading: true - show_signature_annotations: true - separate_signature: false -markdown_extensions: -- abbr -- admonition -- mkdocs-click -- attr_list -- def_list -- footnotes -- md_in_html -- toc: - permalink: true -- 
pymdownx.arithmatex: - generic: true -- pymdownx.betterem: - smart_enable: all -- pymdownx.caret -- pymdownx.details -- pymdownx.emoji: - emoji_generator: !!python/name:materialx.emoji.to_svg '' - emoji_index: !!python/name:materialx.emoji.twemoji '' -- pymdownx.highlight: - anchor_linenums: true -- pymdownx.inlinehilite -- pymdownx.keys -- pymdownx.magiclink: - repo_url_shorthand: true - user: squidfunk - repo: mkdocs-material -- pymdownx.mark -- pymdownx.smartsymbols -- pymdownx.superfences: - custom_fences: - - name: mermaid - class: mermaid - format: !!python/name:pymdownx.superfences.fence_code_format '' -- pymdownx.tabbed: - alternate_style: true -- pymdownx.tasklist: - custom_checkbox: true -- pymdownx.tilde -watch: -- spark_expectations -extra_css: -- css/custom.css -extra: - generator: false - version: - provider: mike - default: latest diff --git a/mkdocs.yml b/mkdocs.yml index c2140835..18afe0f8 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -49,9 +49,12 @@ nav: - Adoption_Guide: - comparision: configurations/adoption_versions_comparsion.md - Examples: + Understand Args: examples.md + Delta: delta.md + BigQuery: bigquery.md + Iceberg: iceberg.md Rules: configurations/rules.md - Configure_Rules: configurations/configure_rules.md - Initialization_Examples: examples.md + Configure Rules: configurations/configure_rules.md - Python API: - Core: Init: api/core_init.md diff --git a/poetry.lock b/poetry.lock index b8610ec6..931f8e5e 100644 --- a/poetry.lock +++ b/poetry.lock @@ -134,13 +134,13 @@ files = [ [[package]] name = "cfgv" -version = "3.3.1" +version = "3.4.0" description = "Validate configuration and produce human readable error messages." optional = false -python-versions = ">=3.6.1" +python-versions = ">=3.8" files = [ - {file = "cfgv-3.3.1-py2.py3-none-any.whl", hash = "sha256:c6a0883f3917a037485059700b9e75da2464e6c27051014ad85ba6aaa5884426"}, - {file = "cfgv-3.3.1.tar.gz", hash = "sha256:f5a830efb9ce7a445376bb66ec94c638a9787422f96264c98edc6bdeed8ab736"}, + {file = "cfgv-3.4.0-py2.py3-none-any.whl", hash = "sha256:b7265b1f29fd3316bfcd2b330d63d024f2bfd8bcb8b0272f8e19a504856c48f9"}, + {file = "cfgv-3.4.0.tar.gz", hash = "sha256:e52591d4c5f5dead8e0f673fb16db7949d2cfb3f7da4582893288f0ded8fe560"}, ] [[package]] @@ -229,13 +229,13 @@ files = [ [[package]] name = "click" -version = "8.1.6" +version = "8.1.7" description = "Composable command line interface toolkit" optional = false python-versions = ">=3.7" files = [ - {file = "click-8.1.6-py3-none-any.whl", hash = "sha256:fa244bb30b3b5ee2cae3da8f55c9e5e0c0e86093306301fb418eb9dc40fbded5"}, - {file = "click-8.1.6.tar.gz", hash = "sha256:48ee849951919527a045bfe3bf7baa8a959c423134e1a5b98c05c20ba75a1cbd"}, + {file = "click-8.1.7-py3-none-any.whl", hash = "sha256:ae74fb96c20a0277a1d615f1e4d73c8414f5a98db8b799a7931d1582f3390c28"}, + {file = "click-8.1.7.tar.gz", hash = "sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de"}, ] [package.dependencies] @@ -326,21 +326,6 @@ files = [ {file = "cyclic-1.0.0.tar.gz", hash = "sha256:ecddd56cb831ee3e6b79f61ecb0ad71caee606c507136867782911aa01c3e5eb"}, ] -[[package]] -name = "delta-spark" -version = "2.4.0" -description = "Python APIs for using Delta Lake with Apache Spark" -optional = false -python-versions = ">=3.6" -files = [ - {file = "delta-spark-2.4.0.tar.gz", hash = "sha256:ef776e325e80d98e3920cab982c747b094acc46599d62dfcdc9035fb112ba6a9"}, - {file = "delta_spark-2.4.0-py3-none-any.whl", hash = "sha256:7204142a97ef16367403b020d810d0c37f4ae8275b4997de4056423cf69b3a4b"}, -] - 
-[package.dependencies] -importlib-metadata = ">=1.0.0" -pyspark = ">=3.4.0,<3.5.0" - [[package]] name = "dill" version = "0.3.7" @@ -379,13 +364,13 @@ files = [ [[package]] name = "exceptiongroup" -version = "1.1.2" +version = "1.1.3" description = "Backport of PEP 654 (exception groups)" optional = false python-versions = ">=3.7" files = [ - {file = "exceptiongroup-1.1.2-py3-none-any.whl", hash = "sha256:e346e69d186172ca7cf029c8c1d16235aa0e04035e5750b4b95039e65204328f"}, - {file = "exceptiongroup-1.1.2.tar.gz", hash = "sha256:12c3e887d6485d16943a309616de20ae5582633e0a2eda17f4e10fd61c1e8af5"}, + {file = "exceptiongroup-1.1.3-py3-none-any.whl", hash = "sha256:343280667a4585d195ca1cf9cef84a4e178c4b6cf2274caef9859782b567d5e3"}, + {file = "exceptiongroup-1.1.3.tar.gz", hash = "sha256:097acd85d473d75af5bb98e41b61ff7fe35efe6675e4f9370ec6ec5126d160e9"}, ] [package.extras] @@ -393,18 +378,19 @@ test = ["pytest (>=6)"] [[package]] name = "filelock" -version = "3.12.2" +version = "3.12.4" description = "A platform independent file lock." optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "filelock-3.12.2-py3-none-any.whl", hash = "sha256:cbb791cdea2a72f23da6ac5b5269ab0a0d161e9ef0100e653b69049a7706d1ec"}, - {file = "filelock-3.12.2.tar.gz", hash = "sha256:002740518d8aa59a26b0c76e10fb8c6e15eae825d34b6fdf670333fd7b938d81"}, + {file = "filelock-3.12.4-py3-none-any.whl", hash = "sha256:08c21d87ded6e2b9da6728c3dff51baf1dcecf973b768ef35bcbc3447edb9ad4"}, + {file = "filelock-3.12.4.tar.gz", hash = "sha256:2e6f249f1f3654291606e046b09f1fd5eac39b360664c27f5aad072012f8bcbd"}, ] [package.extras] -docs = ["furo (>=2023.5.20)", "sphinx (>=7.0.1)", "sphinx-autodoc-typehints (>=1.23,!=1.23.4)"] -testing = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "diff-cover (>=7.5)", "pytest (>=7.3.1)", "pytest-cov (>=4.1)", "pytest-mock (>=3.10)", "pytest-timeout (>=2.1)"] +docs = ["furo (>=2023.7.26)", "sphinx (>=7.1.2)", "sphinx-autodoc-typehints (>=1.24)"] +testing = ["covdefaults (>=2.3)", "coverage (>=7.3)", "diff-cover (>=7.7)", "pytest (>=7.4)", "pytest-cov (>=4.1)", "pytest-mock (>=3.11.1)", "pytest-timeout (>=2.1)"] +typing = ["typing-extensions (>=4.7.1)"] [[package]] name = "flake8" @@ -469,27 +455,30 @@ smmap = ">=3.0.1,<6" [[package]] name = "gitpython" -version = "3.1.32" +version = "3.1.36" description = "GitPython is a Python library used to interact with Git repositories" optional = false python-versions = ">=3.7" files = [ - {file = "GitPython-3.1.32-py3-none-any.whl", hash = "sha256:e3d59b1c2c6ebb9dfa7a184daf3b6dd4914237e7488a1730a6d8f6f5d0b4187f"}, - {file = "GitPython-3.1.32.tar.gz", hash = "sha256:8d9b8cb1e80b9735e8717c9362079d3ce4c6e5ddeebedd0361b228c3a67a62f6"}, + {file = "GitPython-3.1.36-py3-none-any.whl", hash = "sha256:8d22b5cfefd17c79914226982bb7851d6ade47545b1735a9d010a2a4c26d8388"}, + {file = "GitPython-3.1.36.tar.gz", hash = "sha256:4bb0c2a6995e85064140d31a33289aa5dce80133a23d36fcd372d716c54d3ebf"}, ] [package.dependencies] gitdb = ">=4.0.1,<5" +[package.extras] +test = ["black", "coverage[toml]", "ddt (>=1.1.1,!=1.4.3)", "mypy", "pre-commit", "pytest", "pytest-cov", "pytest-sugar", "virtualenv"] + [[package]] name = "griffe" -version = "0.32.3" +version = "0.36.2" description = "Signatures for entire Python programs. Extract the structure, the frame, the skeleton of your project, to generate API documentation or find breaking changes in your API." 
optional = false python-versions = ">=3.8" files = [ - {file = "griffe-0.32.3-py3-none-any.whl", hash = "sha256:d9471934225818bf8f309822f70451cc6abb4b24e59e0bb27402a45f9412510f"}, - {file = "griffe-0.32.3.tar.gz", hash = "sha256:14983896ad581f59d5ad7b6c9261ff12bdaa905acccc1129341d13e545da8521"}, + {file = "griffe-0.36.2-py3-none-any.whl", hash = "sha256:ba71895a3f5f606b18dcd950e8a1f8e7332a37f90f24caeb002546593f2e0eee"}, + {file = "griffe-0.36.2.tar.gz", hash = "sha256:333ade7932bb9096781d83092602625dfbfe220e87a039d2801259a1bd41d1c2"}, ] [package.dependencies] @@ -497,13 +486,13 @@ colorama = ">=0.4" [[package]] name = "identify" -version = "2.5.26" +version = "2.5.29" description = "File identification library for Python" optional = false python-versions = ">=3.8" files = [ - {file = "identify-2.5.26-py2.py3-none-any.whl", hash = "sha256:c22a8ead0d4ca11f1edd6c9418c3220669b3b7533ada0a0ffa6cc0ef85cf9b54"}, - {file = "identify-2.5.26.tar.gz", hash = "sha256:7243800bce2f58404ed41b7c002e53d4d22bcf3ae1b7900c2d7aefd95394bf7f"}, + {file = "identify-2.5.29-py2.py3-none-any.whl", hash = "sha256:24437fbf6f4d3fe6efd0eb9d67e24dd9106db99af5ceb27996a5f7895f24bf1b"}, + {file = "identify-2.5.29.tar.gz", hash = "sha256:d43d52b86b15918c137e3a74fff5224f60385cd0e9c38e99d07c257f02f151a5"}, ] [package.extras] @@ -907,17 +896,17 @@ python-legacy = ["mkdocstrings-python-legacy (>=0.2.1)"] [[package]] name = "mkdocstrings-python" -version = "1.3.0" +version = "1.7.0" description = "A Python handler for mkdocstrings." optional = false python-versions = ">=3.8" files = [ - {file = "mkdocstrings_python-1.3.0-py3-none-any.whl", hash = "sha256:36c224c86ab77e90e0edfc9fea3307f7d0d245dd7c28f48bbb2203cf6e125530"}, - {file = "mkdocstrings_python-1.3.0.tar.gz", hash = "sha256:f967f84bab530fcc13cc9c02eccf0c18bdb2c3bab5c55fa2045938681eec4fc4"}, + {file = "mkdocstrings_python-1.7.0-py3-none-any.whl", hash = "sha256:85c5f009a5a0ebb6076b7818c82a2bb0eebd0b54662628fa8b25ee14a6207951"}, + {file = "mkdocstrings_python-1.7.0.tar.gz", hash = "sha256:5dac2712bd38a3ff0812b8650a68b232601d1474091b380a8b5bc102c8c0d80a"}, ] [package.dependencies] -griffe = ">=0.30,<0.33" +griffe = ">=0.35" mkdocstrings = ">=0.20" [[package]] @@ -1281,19 +1270,22 @@ pylint = ">=1.7" [[package]] name = "pymdown-extensions" -version = "10.1" +version = "10.3" description = "Extension pack for Python Markdown." 
optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "pymdown_extensions-10.1-py3-none-any.whl", hash = "sha256:ef25dbbae530e8f67575d222b75ff0649b1e841e22c2ae9a20bad9472c2207dc"}, - {file = "pymdown_extensions-10.1.tar.gz", hash = "sha256:508009b211373058debb8247e168de4cbcb91b1bff7b5e961b2c3e864e00b195"}, + {file = "pymdown_extensions-10.3-py3-none-any.whl", hash = "sha256:77a82c621c58a83efc49a389159181d570e370fff9f810d3a4766a75fc678b66"}, + {file = "pymdown_extensions-10.3.tar.gz", hash = "sha256:94a0d8a03246712b64698af223848fd80aaf1ae4c4be29c8c61939b0467b5722"}, ] [package.dependencies] markdown = ">=3.2" pyyaml = "*" +[package.extras] +extra = ["pygments (>=2.12)"] + [[package]] name = "pyspark" version = "3.4.1" @@ -1578,13 +1570,13 @@ toml = ">=0.10.2,<0.11.0" [[package]] name = "s3transfer" -version = "0.6.1" +version = "0.6.2" description = "An Amazon S3 Transfer Manager" optional = false python-versions = ">= 3.7" files = [ - {file = "s3transfer-0.6.1-py3-none-any.whl", hash = "sha256:3c0da2d074bf35d6870ef157158641178a4204a6e689e82546083e31e0311346"}, - {file = "s3transfer-0.6.1.tar.gz", hash = "sha256:640bb492711f4c0c0905e1f62b6aaeb771881935ad27884852411f8e9cacbca9"}, + {file = "s3transfer-0.6.2-py3-none-any.whl", hash = "sha256:b014be3a8a2aab98cfe1abc7229cc5a9a0cf05eb9c1f2b86b230fd8df3f78084"}, + {file = "s3transfer-0.6.2.tar.gz", hash = "sha256:cab66d3380cca3e70939ef2255d01cd8aece6a4907a9528740f668c4b0611861"}, ] [package.dependencies] @@ -1620,19 +1612,19 @@ yaml = ["pyyaml"] [[package]] name = "setuptools" -version = "68.0.0" +version = "68.2.2" description = "Easily download, build, install, upgrade, and uninstall Python packages" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "setuptools-68.0.0-py3-none-any.whl", hash = "sha256:11e52c67415a381d10d6b462ced9cfb97066179f0e871399e006c4ab101fc85f"}, - {file = "setuptools-68.0.0.tar.gz", hash = "sha256:baf1fdb41c6da4cd2eae722e135500da913332ab3f2f5c7d33af9b492acb5235"}, + {file = "setuptools-68.2.2-py3-none-any.whl", hash = "sha256:b454a35605876da60632df1a60f736524eb73cc47bbc9f3f1ef1b644de74fd2a"}, + {file = "setuptools-68.2.2.tar.gz", hash = "sha256:4ac1475276d2f1c48684874089fefcd83bd7162ddaafb81fac866ba0db282a87"}, ] [package.extras] -docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-hoverxref (<2)", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (==0.8.3)", "sphinx-reredirects", "sphinxcontrib-towncrier"] -testing = ["build[virtualenv]", "filelock (>=3.4.0)", "flake8-2020", "ini2toml[lite] (>=0.9)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pip (>=19.1)", "pip-run (>=8.8)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-mypy (>=0.9.1)", "pytest-perf", "pytest-ruff", "pytest-timeout", "pytest-xdist", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] -testing-integration = ["build[virtualenv]", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"] +docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-hoverxref (<2)", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", 
"sphinxcontrib-towncrier"] +testing = ["build[virtualenv]", "filelock (>=3.4.0)", "flake8-2020", "ini2toml[lite] (>=0.9)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pip (>=19.1)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-perf", "pytest-ruff", "pytest-timeout", "pytest-xdist", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] +testing-integration = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "packaging (>=23.1)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"] [[package]] name = "six" @@ -1779,13 +1771,13 @@ test = ["coverage", "flake8 (>=3.7)", "mypy", "pretend", "pytest"] [[package]] name = "virtualenv" -version = "20.24.3" +version = "20.24.5" description = "Virtual Python Environment builder" optional = false python-versions = ">=3.7" files = [ - {file = "virtualenv-20.24.3-py3-none-any.whl", hash = "sha256:95a6e9398b4967fbcb5fef2acec5efaf9aa4972049d9ae41f95e0972a683fd02"}, - {file = "virtualenv-20.24.3.tar.gz", hash = "sha256:e5c3b4ce817b0b328af041506a2a299418c98747c4b1e68cb7527e74ced23efc"}, + {file = "virtualenv-20.24.5-py3-none-any.whl", hash = "sha256:b80039f280f4919c77b30f1c23294ae357c4c8701042086e3fc005963e4e537b"}, + {file = "virtualenv-20.24.5.tar.gz", hash = "sha256:e8361967f6da6fbdf1426483bfe9fca8287c242ac0bc30429905721cefbff752"}, ] [package.dependencies] @@ -1794,7 +1786,7 @@ filelock = ">=3.12.2,<4" platformdirs = ">=3.9.1,<4" [package.extras] -docs = ["furo (>=2023.5.20)", "proselint (>=0.13)", "sphinx (>=7.0.1)", "sphinx-argparse (>=0.4)", "sphinxcontrib-towncrier (>=0.2.1a0)", "towncrier (>=23.6)"] +docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.2)", "sphinx-argparse (>=0.4)", "sphinxcontrib-towncrier (>=0.2.1a0)", "towncrier (>=23.6)"] test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess (>=1)", "flaky (>=3.7)", "packaging (>=23.1)", "pytest (>=7.4)", "pytest-env (>=0.8.2)", "pytest-freezer (>=0.4.8)", "pytest-mock (>=3.11.1)", "pytest-randomly (>=3.12)", "pytest-timeout (>=2.1)", "setuptools (>=68)", "time-machine (>=2.10)"] [[package]] @@ -1938,4 +1930,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.8.9" -content-hash = "7f400e526ee55937bbad6a4a141320bd06deac422d5af3e2f032a493a4d972b9" +content-hash = "33662f739dabf3353ed3797c9f9db849ceefc52407de67d24f7e0f8f57e3a622" diff --git a/prospector.yaml b/prospector.yaml index b8853754..9d32f4e9 100644 --- a/prospector.yaml +++ b/prospector.yaml @@ -40,7 +40,7 @@ pylint: pycodestyle: # W293: disabled because we have newlines in docstrings # E203: disabled because pep8 and black disagree on whitespace before colon in some cases - disable: W293,E203,E203 # conflicts with black formatting + disable: W293,E203, E203 # conflicts with black formatting mccabe: disable: diff --git a/pyproject.toml b/pyproject.toml index 15dee832..7c66fdd6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,6 @@ packages = [{ include = "spark_expectations" }] python = "^3.8.9" pluggy = "1.0.0" pyspark = "^3.0.0" -delta-spark = "^2.1.0" requests = "^2.28.1" [tool.poetry.group.dev.dependencies] @@ -19,7 +18,6 @@ pytest = "7.3.1" pytest-mock = "3.10.0" coverage = "7.2.5" pyspark = "^3.0.0" -delta-spark = "^2.1.0" mypy = "1.3.0" mkdocs = "1.4.3" prospector = "1.10.0" diff 
--git a/spark_expectations/config/user_config.py b/spark_expectations/config/user_config.py index 614249f8..77936e94 100644 --- a/spark_expectations/config/user_config.py +++ b/spark_expectations/config/user_config.py @@ -23,15 +23,6 @@ class Constants: "spark.expectations.notifications.slack.webhook_url" ) - se_agg_dq = "agg_dq" - se_source_agg_dq = "source_agg_dq" - se_final_agg_dq = "final_agg_dq" - - se_query_dq = "query_dq" - se_source_query_dq = "source_query_dq" - se_final_query_dq = "final_query_dq" - se_target_table_view = "target_table_view" - se_notifications_on_start = "spark.expectations.notifications.on_start" se_notifications_on_completion = "spark.expectations.notifications.on.completion" se_notifications_on_fail = "spark.expectations.notifications.on.fail" diff --git a/spark_expectations/core/__init__.py b/spark_expectations/core/__init__.py index bbe23262..fb979daf 100644 --- a/spark_expectations/core/__init__.py +++ b/spark_expectations/core/__init__.py @@ -9,12 +9,11 @@ def get_spark_session() -> SparkSession: os.environ.get("UNIT_TESTING_ENV") == "spark_expectations_unit_testing_on_github_actions" ) or (os.environ.get("SPARKEXPECTATIONS_ENV") == "local"): - from delta import configure_spark_with_delta_pip - builder = ( SparkSession.builder.config( "spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension" ) + .config("spark.jars.packages", "io.delta:delta-core_2.12:2.4.0") .config( "spark.sql.catalog.spark_catalog", "org.apache.spark.sql.delta.catalog.DeltaCatalog", @@ -30,6 +29,6 @@ def get_spark_session() -> SparkSession: f"{current_dir}/../../jars/spark-token-provider-kafka-0-10_2.12-3.0.0.jar", ) ) - return configure_spark_with_delta_pip(builder).getOrCreate() + return builder.getOrCreate() - return SparkSession.builder.getOrCreate() + return SparkSession.getActiveSession() diff --git a/spark_expectations/core/context.py b/spark_expectations/core/context.py index 466465de..ecde0bad 100644 --- a/spark_expectations/core/context.py +++ b/spark_expectations/core/context.py @@ -5,8 +5,7 @@ from dataclasses import dataclass from uuid import uuid1 from typing import Dict, Optional, List -from pyspark.sql import DataFrame -from spark_expectations.core import get_spark_session +from pyspark.sql import DataFrame, SparkSession from spark_expectations.config.user_config import Constants as user_config from spark_expectations.core.exceptions import SparkExpectationsMiscException @@ -20,9 +19,9 @@ class SparkExpectationsContext: """ product_id: str + spark: SparkSession def __post_init__(self) -> None: - self.spark = get_spark_session() self._run_id: str = f"{self.product_id}_{uuid1()}" self._run_date: str = self.set_run_date() self._dq_stats_table_name: Optional[str] = None @@ -75,10 +74,8 @@ def __post_init__(self) -> None: self._notification_on_fail: bool = False self._error_drop_threshold: int = 100 - self._cerberus_url: str = "https://prod.cerberus.cloud.com" - self._cerberus_cred_path: str = ( - "app/aplaedaengineering-domo/dq-spark-expectations" - ) + self._cerberus_url: str = "your_cerberus_url" + self._cerberus_cred_path: str = "your_cerberus_sdb_path" self._cerberus_token: Optional[str] = os.environ.get( "DQ_SPARK_EXPECTATIONS_CERBERUS_TOKEN" ) @@ -86,10 +83,6 @@ def __post_init__(self) -> None: self._se_streaming_stats_dict: Dict[str, str] self._enable_se_streaming: bool = False self._se_streaming_secret_env: str = "" - # self._se_streaming_bootstrap_server_url_path: str = "" - # self._se_streaming_secret_path: str = "" - # 
self._se_streaming_token_endpoint_uri_path: str = ""
-        # self._se_streaming_client_id_path: str = ""
         self._debugger_mode: bool = False
         self._supported_df_query_dq: DataFrame = self.set_supported_df_query_dq()
@@ -127,6 +120,9 @@ def __post_init__(self) -> None:
         self._summarised_row_dq_res: Optional[List[Dict[str, str]]] = None
         self._rules_error_per: Optional[List[dict]] = None
 
+        self._target_and_error_table_writer_config: dict = {}
+        self._stats_table_writer_config: dict = {}
+
     @property
     def get_run_id(self) -> str:
         """
@@ -971,7 +967,7 @@ def set_se_streaming_stats_topic_name(
     @property
     def get_se_streaming_stats_topic_name(self) -> str:
         """
-        This function returns nsp topic name
+        This function returns the Kafka topic name
 
         Returns:
             str: Returns _se_streaming_stats_topic_name
@@ -1469,3 +1465,41 @@ def get_rules_exceeds_threshold(self) -> Optional[List[dict]]:
         This function returns error percentage for each rule
         """
         return self._rules_error_per
+
+    def set_target_and_error_table_writer_config(self, config: dict) -> None:
+        """
+        This function sets target and error table writer config
+        Args:
+            config: dict
+        Returns: None
+
+        """
+        self._target_and_error_table_writer_config = config
+
+    @property
+    def get_target_and_error_table_writer_config(self) -> dict:
+        """
+        This function returns target and error table writer config
+        Returns:
+            dict: Returns target_and_error_table_writer_config as a dict
+
+        """
+        return self._target_and_error_table_writer_config
+
+    def set_stats_table_writer_config(self, config: dict) -> None:
+        """
+        This function sets stats table writer config
+        Args:
+            config: dict
+        Returns: None
+        """
+        self._stats_table_writer_config = config
+
+    @property
+    def get_stats_table_writer_config(self) -> dict:
+        """
+        This function returns stats table writer config
+        Returns:
+            dict: Returns stats_table_writer_config as a dict
+        """
+        return self._stats_table_writer_config
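
The setter/getter pairs added above simply store and return the writer configuration dictionary produced by `WrappedDataFrameWriter.build()` (introduced later in this patch). A minimal sketch of that round trip, assuming an active Spark session and using a hand-written dict that mirrors the keys `build()` returns:

```python
from pyspark.sql import SparkSession
from spark_expectations.core.context import SparkExpectationsContext

# Assumes a session is already running, e.g. created by one of the example setups.
spark = SparkSession.getActiveSession()
context = SparkExpectationsContext(product_id="your_product", spark=spark)

# Hand-written stand-in for WrappedDataFrameWriter().mode("append").format("delta").build()
writer_config = {
    "mode": "append",
    "format": "delta",
    "partitionBy": [],
    "options": {},
    "bucketBy": {},
    "sortBy": [],
}

context.set_target_and_error_table_writer_config(writer_config)
assert context.get_target_and_error_table_writer_config == writer_config
```
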
diff --git a/spark_expectations/core/expectations.py b/spark_expectations/core/expectations.py
index 47baa0b8..ff57921d 100644
--- a/spark_expectations/core/expectations.py
+++ b/spark_expectations/core/expectations.py
@@ -1,12 +1,10 @@
 import functools
 from dataclasses import dataclass
 from typing import Dict, Optional, Any, Union
-
-from pyspark.sql import DataFrame
-
+from pyspark import StorageLevel
+from pyspark.sql import DataFrame, SparkSession
 from spark_expectations import _log
 from spark_expectations.config.user_config import Constants as user_config
-from spark_expectations.core import get_spark_session
 from spark_expectations.core.context import SparkExpectationsContext
 from spark_expectations.core.exceptions import (
     SparkExpectationsMiscException,
@@ -28,102 +26,87 @@ class SparkExpectations:
     """
     This class implements/supports running the data quality rules on a dataframe returned by a function
+
+    Args:
+        product_id: Name of the product
+        rules_df: DataFrame which contains the rules. User is responsible for reading
+            the rules table in whichever system it is in
+        stats_table: Name of the table where the stats/audit-info needs to be written
+        target_and_error_table_writer: Writer configuration used to write the target and error tables
+        stats_table_writer: Writer configuration used to write the stats table
+        debugger: Mark it as "True" if the debugger mode needs to be enabled; by default it is False
+        stats_streaming_options: Provide options to override the defaults, while writing into the stats streaming table
     """
 
     product_id: str
+    rules_df: DataFrame
+    stats_table: str
+    target_and_error_table_writer: "WrappedDataFrameWriter"
+    stats_table_writer: "WrappedDataFrameWriter"
     debugger: bool = False
-    stats_streaming_options: Optional[Dict[str, str]] = None
+    stats_streaming_options: Optional[Dict[str, Union[str, bool]]] = None
 
     def __post_init__(self) -> None:
-        self.spark = get_spark_session()
-        self.actions = SparkExpectationsActions()
-        self._context = SparkExpectationsContext(product_id=self.product_id)
-
-        self._writer = SparkExpectationsWriter(
-            product_id=self.product_id, _context=self._context
+        if isinstance(self.rules_df, DataFrame):
+            self.spark: SparkSession = self.rules_df.sparkSession
+        else:
+            raise SparkExpectationsMiscException(
+                "Input rules_df is not of dataframe type"
+            )
+        self.actions: SparkExpectationsActions = SparkExpectationsActions()
+        self._context: SparkExpectationsContext = SparkExpectationsContext(
+            product_id=self.product_id, spark=self.spark
         )
+
+        self._writer = SparkExpectationsWriter(_context=self._context)
         self._process = SparkExpectationsRegulateFlow(product_id=self.product_id)
         self._notification: SparkExpectationsNotify = SparkExpectationsNotify(
-            product_id=self.product_id, _context=self._context
+            _context=self._context
         )
         self._statistics_decorator = SparkExpectationsCollectStatistics(
-            product_id=self.product_id,
             _context=self._context,
             _writer=self._writer,
         )
+        self.reader = SparkExpectationsReader(_context=self._context)
 
-        self.reader = SparkExpectationsReader(
-            product_id=self.product_id,
-            _context=self._context,
+        self._context.set_target_and_error_table_writer_config(
+            self.target_and_error_table_writer.build()
         )
-
+        self._context.set_stats_table_writer_config(self.stats_table_writer.build())
         self._context.set_debugger_mode(self.debugger)
+        self._context.set_dq_stats_table_name(self.stats_table)
+        self.rules_df = self.rules_df.persist(StorageLevel.MEMORY_AND_DISK)
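
Since `rules_df` is now a mandatory argument and the `SparkSession` is derived from it, the refactored `__post_init__` above fails fast when anything other than a DataFrame is supplied. A small sketch of that guard, with placeholder table names:

```python
from spark_expectations.core.exceptions import SparkExpectationsMiscException
from spark_expectations.core.expectations import (
    SparkExpectations,
    WrappedDataFrameWriter,
)

writer = WrappedDataFrameWriter().mode("append").format("delta")

try:
    SparkExpectations(
        product_id="your_product",
        rules_df="dq_spark_local.dq_rules",  # a table *name*, not a DataFrame
        stats_table="dq_spark_local.dq_stats",
        stats_table_writer=writer,
        target_and_error_table_writer=writer,
    )
except SparkExpectationsMiscException as exc:
    print(exc)  # Input rules_df is not of dataframe type
```
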
 
+    # TODO Add target_error_table_writer and stats_table_writer as parameters to this function so this takes precedence
+    # if user provides it
     def with_expectations(
         self,
-        expectations: dict,
+        target_table: str,
         write_to_table: bool = False,
         write_to_temp_table: bool = False,
-        row_dq: bool = True,
-        agg_dq: Optional[Dict[str, bool]] = None,
-        query_dq: Optional[Dict[str, Union[str, bool]]] = None,
-        spark_conf: Optional[Dict[str, Any]] = None,
-        options: Optional[Dict[str, str]] = None,
-        options_error_table: Optional[Dict[str, str]] = None,
+        user_conf: Optional[Dict[str, Union[str, int, bool]]] = None,
+        target_table_view: Optional[str] = None,
+        target_and_error_table_writer: Optional["WrappedDataFrameWriter"] = None,
     ) -> Any:
         """
         This decorator helps to wrap a function which returns dataframe and apply dataframe rules on it
 
         Args:
-            expectations: Dict of dict's with table and rules as keys
+            target_table: Name of the table where the final dataframe needs to be written
             write_to_table: Mark it as "True" if the dataframe need to be written as table
             write_to_temp_table: Mark it as "True" if the input dataframe need to be written to the temp table to
             break the spark plan
-            row_dq: Mark it as False to avoid row level expectation, by default is TRUE,
-            agg_dq: There are several dictionary variables that are used for data quality (DQ) aggregation in both the
-            source and final DQ layers
-            agg_dq => Mark it as True to run agg level expectation, by default is False
-            source_agg_dq => Mark it as True to run source agg level expectation, by default is False
-            final_agg_dq => Mark it as True to run final agg level expectation, by default is False
-            query_dq: There are several dictionary variables that are used for data quality (DQ) using query in both
-            the source and final DQ layers
-            query_dq => Mark it as True to run query level expectation, by default is False
-            source_query_dq => Mark it as True to run query dq level expectation, by default is False
-            final_query_dq => Mark it as True to run query dq level expectation, by default is False
-            spark_conf: Provide SparkConf to override the defaults, while writing into the table & which also contains
-            notifications related variables
-            options: Provide Options to override the defaults, while writing into the table
-            options_error_table: Provide options to override the defaults, while writing into the error table
+            user_conf: Provide notification-related configurations to override the defaults (on_start,
+                on_completion, on_fail and error-drop-threshold settings)
+            target_table_view: This view is created after the _row_dq process to run the target agg_dq and query_dq.
+                If a value is not provided, it defaults to {target_table}_view
+            target_and_error_table_writer: Provide the writer to write the target and error table;
+                this will take precedence over the class-level writer
 
         Returns:
             Any: Returns a function which applied the expectations on dataset
         """
 
         def _except(func: Any) -> Any:
-            # variable used for enabling source agg dq at different level
-            _default_agg_dq_dict: Dict[str, bool] = {
-                user_config.se_agg_dq: False,
-                user_config.se_source_agg_dq: False,
-                user_config.se_final_agg_dq: False,
-            }
-            _agg_dq_dict: Dict[str, bool] = (
-                {**_default_agg_dq_dict, **agg_dq} if agg_dq else _default_agg_dq_dict
-            )
-
-            # variable used for enabling query dq at different level
-            _default_query_dq_dict: Dict[str, Union[str, bool]] = {
-                user_config.se_query_dq: False,
-                user_config.se_source_query_dq: False,
-                user_config.se_final_query_dq: False,
-                user_config.se_target_table_view: "",
-            }
-            _query_dq_dict: Dict[str, Union[str, bool]] = (
-                {**_default_query_dq_dict, **query_dq}
-                if query_dq
-                else _default_query_dq_dict
-            )
-
             # variable used for enabling notification at different level
+
             _default_notification_dict: Dict[str, Union[str, int, bool]] = {
                 user_config.se_notifications_on_start: False,
                 user_config.se_notifications_on_completion: False,
@@ -132,68 +115,51 @@ def _except(func: Any) -> Any:
                 user_config.se_notifications_on_error_drop_threshold: 100,
             }
             _notification_dict: Dict[str, Union[str, int, bool]] = (
-                {**_default_notification_dict, **spark_conf}
-                if spark_conf
+                {**_default_notification_dict, **user_conf}
+                if user_conf
                 else _default_notification_dict
             )
-
             _default_stats_streaming_dict: Dict[str, Union[bool, str]] = {
                 user_config.se_enable_streaming: True,
                 user_config.secret_type: "databricks",
                 user_config.dbx_workspace_url: "https://workspace.cloud.databricks.com",
-                user_config.dbx_secret_scope: "sole_common_prod",
+                user_config.dbx_secret_scope: "secret_scope",
                 user_config.dbx_kafka_server_url: "se_streaming_server_url_secret_key",
                 user_config.dbx_secret_token_url: "se_streaming_auth_secret_token_url_key",
                 user_config.dbx_secret_app_name: "se_streaming_auth_secret_appid_key",
                 user_config.dbx_secret_token: 
"se_streaming_auth_secret_token_key", user_config.dbx_topic_name: "se_streaming_topic_name", } - _se_stats_streaming_dict: Dict[str, Any] = ( {**self.stats_streaming_options} if self.stats_streaming_options else _default_stats_streaming_dict ) - _agg_dq: bool = ( - _agg_dq_dict[user_config.se_agg_dq] - if isinstance(_agg_dq_dict[user_config.se_agg_dq], bool) - else False - ) - _source_agg_dq: bool = ( - _agg_dq_dict[user_config.se_source_agg_dq] - if isinstance(_agg_dq_dict[user_config.se_source_agg_dq], bool) - else False - ) - _final_agg_dq: bool = ( - _agg_dq_dict[user_config.se_final_agg_dq] - if isinstance(_agg_dq_dict[user_config.se_final_agg_dq], bool) - else False - ) + # Overwrite the writers if provided by the user in the with_expectations explicitly + if target_and_error_table_writer: + self._context.set_target_and_error_table_writer_config( + target_and_error_table_writer.build() + ) - _query_dq: bool = ( - bool(_query_dq_dict[user_config.se_query_dq]) - if isinstance(_query_dq_dict[user_config.se_query_dq], bool) - else False + # need to call the get_rules_frm_table function to get the rules from the table as expectations + expectations, rules_execution_settings = self.reader.get_rules_from_df( + self.rules_df, target_table ) - _source_query_dq: bool = ( - bool(_query_dq_dict[user_config.se_source_query_dq]) - if isinstance(_query_dq_dict[user_config.se_source_query_dq], bool) - else False + + _row_dq: bool = rules_execution_settings.get("row_dq", False) + _source_agg_dq: bool = rules_execution_settings.get("source_agg_dq", False) + _target_agg_dq: bool = rules_execution_settings.get("target_agg_dq", False) + _source_query_dq: bool = rules_execution_settings.get( + "source_query_dq", False ) - _final_query_dq: bool = ( - bool(_query_dq_dict[user_config.se_final_query_dq]) - if isinstance(_query_dq_dict[user_config.se_final_query_dq], bool) - else False + _target_query_dq: bool = rules_execution_settings.get( + "target_query_dq", False ) - _target_table_view: str = ( - str(_query_dq_dict[user_config.se_target_table_view]) - if isinstance( - _query_dq_dict[user_config.se_target_table_view], - str, - ) - else "" + target_table_view + if target_table_view + else f"{target_table.split('.')[-1]}_view" ) _notification_on_start: bool = ( @@ -249,7 +215,7 @@ def _except(func: Any) -> Any: else 100 ) - self.reader.set_notification_param(spark_conf) + self.reader.set_notification_param(user_conf) self._context.set_notification_on_start(_notification_on_start) self._context.set_notification_on_completion(_notification_on_completion) self._context.set_notification_on_fail(_notification_on_fail) @@ -271,7 +237,7 @@ def wrapper(*args: tuple, **kwargs: dict) -> DataFrame: _error_count: int = 0 _source_dq_df: Optional[DataFrame] = None _source_query_dq_df: Optional[DataFrame] = None - _row_dq_df: Optional[DataFrame] = None + _row_dq_df: DataFrame = _df _final_dq_df: Optional[DataFrame] = None _final_query_dq_df: Optional[DataFrame] = None @@ -315,11 +281,10 @@ def wrapper(*args: tuple, **kwargs: dict) -> DataFrame: self.spark.sql(f"drop table if exists {table_name}_temp") _log.info("Dropping to temp table completed") _log.info("Writing to temp table started") - self._writer.write_df_to_table( + self._writer.save_df_as_table( _df, f"{table_name}_temp", - spark_conf=spark_conf, - options=options, + self._context.get_target_and_error_table_writer_config, ) _log.info("Read from temp table started") _df = self.spark.sql(f"select * from {table_name}_temp") @@ -333,11 +298,9 @@ def wrapper(*args: 
tuple, **kwargs: dict) -> DataFrame: expectations=expectations, table_name=table_name, _input_count=_input_count, - spark_conf=spark_conf, - options_error_table=options_error_table, ) - if _agg_dq is True and _source_agg_dq is True: + if _source_agg_dq is True: _log.info( "started processing data quality rules for agg level expectations on soure dataframe" ) @@ -367,7 +330,7 @@ def wrapper(*args: tuple, **kwargs: dict) -> DataFrame: "ended processing data quality rules for agg level expectations on source dataframe" ) - if _query_dq is True and _source_query_dq is True: + if _source_query_dq is True: _log.info( "started processing data quality rules for query level expectations on soure dataframe" ) @@ -396,7 +359,7 @@ def wrapper(*args: tuple, **kwargs: dict) -> DataFrame: "ended processing data quality rules for query level expectations on source dataframe" ) - if row_dq is True: + if _row_dq is True: _log.info( "started processing data quality rules for row level expectations" ) @@ -415,10 +378,9 @@ def wrapper(*args: tuple, **kwargs: dict) -> DataFrame: ) self._context.set_error_count(_error_count) - if _target_table_view: - _row_dq_df.createOrReplaceTempView(_target_table_view) + _row_dq_df.createOrReplaceTempView(_target_table_view) - _output_count = _row_dq_df.count() + _output_count = _row_dq_df.count() if _row_dq_df else 0 self._context.set_output_count(_output_count) self._context.set_row_dq_status(status) @@ -441,7 +403,7 @@ def wrapper(*args: tuple, **kwargs: dict) -> DataFrame: "ended processing data quality rules for row level expectations" ) - if row_dq is True and _agg_dq is True and _final_agg_dq is True: + if _row_dq is True and _target_agg_dq is True: _log.info( "started processing data quality rules for agg level expectations on final dataframe" ) @@ -471,11 +433,7 @@ def wrapper(*args: tuple, **kwargs: dict) -> DataFrame: "ended processing data quality rules for agg level expectations on final dataframe" ) - if ( - row_dq is True - and _query_dq is True - and _final_query_dq is True - ): + if _row_dq is True and _target_query_dq is True: _log.info( "started processing data quality rules for query level expectations on final dataframe" ) @@ -488,12 +446,7 @@ def wrapper(*args: tuple, **kwargs: dict) -> DataFrame: # _: number of error records # status: status of the execution - if _target_table_view and _row_dq_df: - _row_dq_df.createOrReplaceTempView(_target_table_view) - else: - raise SparkExpectationsMiscException( - "final table view name is not supplied to run query dq" - ) + _row_dq_df.createOrReplaceTempView(_target_table_view) ( _final_query_dq_df, @@ -514,13 +467,14 @@ def wrapper(*args: tuple, **kwargs: dict) -> DataFrame: "ended processing data quality rules for query level expectations on final dataframe" ) - if row_dq and write_to_table: + # TODO if row_dq is False and source_agg/source_query is True then we need to write the + # dataframe into the target table + if write_to_table: _log.info("Writing into the final table started") - self._writer.write_df_to_table( + self._writer.save_df_as_table( _row_dq_df, f"{table_name}", - spark_conf=spark_conf, - options=options, + self._context.get_target_and_error_table_writer_config, ) _log.info("Writing into the final table ended") @@ -541,3 +495,118 @@ def wrapper(*args: tuple, **kwargs: dict) -> DataFrame: return wrapper return _except + + +class WrappedDataFrameWriter: + """ + A builder pattern class that mimics the functions of PySpark's DataFrameWriter. 
+
+    This class allows for chaining methods to set configurations like mode, format, partitioning,
+    options, and bucketing. It does not require a DataFrame object and is designed purely to
+    collect and return configurations.
+
+    Example usage:
+    --------------
+    writer = WrappedDataFrameWriter().mode("overwrite")\
+                                     .format("parquet")\
+                                     .partitionBy("date", "region")\
+                                     .option("compression", "gzip")\
+                                     .options(path="/path/to/output", inferSchema="true")\
+                                     .bucketBy(4, "country", "city")
+
+    config = writer.build()
+    print(config)
+
+    Attributes:
+    -----------
+    _mode : str
+        The mode for writing (e.g., "overwrite", "append").
+    _format : str
+        The format for writing (e.g., "parquet", "csv").
+    _partition_by : list
+        Columns by which the data should be partitioned.
+    _options : dict
+        Additional options for writing.
+    _bucket_by : dict
+        Configuration for bucketing, including number of buckets and columns.
+    _sort_by : list
+        Columns by which the bucketed data should be sorted.
+    """
+
+    def __init__(self) -> None:
+        self._mode: Optional[str] = None
+        self._format: Optional[str] = None
+        self._partition_by: list = []
+        self._options: Dict[str, str] = {}
+        self._bucket_by: Dict[str, Union[int, tuple]] = {}
+        self._sort_by: list = []
+
+    def mode(self, saveMode: str) -> "WrappedDataFrameWriter":  # noqa: N803
+        """Set the mode for writing."""
+        self._mode = saveMode
+        return self
+
+    def format(self, source: str) -> "WrappedDataFrameWriter":
+        """Set the format for writing."""
+        self._format = source
+        return self
+
+    def partitionBy(self, *columns: str) -> "WrappedDataFrameWriter":  # noqa: N802
+        """Set the columns by which the data should be partitioned."""
+        self._partition_by.extend(columns)
+        return self
+
+    def option(self, key: str, value: str) -> "WrappedDataFrameWriter":
+        """Set a single option for writing."""
+        self._options[key] = value
+        return self
+
+    def options(self, **options: str) -> "WrappedDataFrameWriter":
+        """Set multiple options for writing."""
+        self._options.update(options)
+        return self
+
+    def bucketBy(  # noqa: N802
+        self, num_buckets: int, *columns: str
+    ) -> "WrappedDataFrameWriter":
+        """Set the configuration for bucketing."""
+        self._bucket_by["num_buckets"] = num_buckets
+        self._bucket_by["columns"] = columns
+        return self
+
+    def sortBy(self, *columns: str) -> "WrappedDataFrameWriter":  # noqa: N802
+        """Set the columns by which the bucketed data should be sorted."""
+        self._sort_by.extend(columns)
+        return self
+
+    def build(self) -> Dict[str, Union[str, list, dict, tuple, int, None]]:
+        """Return the collected configurations."""
+        if self._format is not None and self._format.lower() == "delta":
+            if self._bucket_by is not None and self._bucket_by:
+                raise SparkExpectationsMiscException(
+                    "Bucketing is not supported for delta tables yet..."
+ ) + + return { + "mode": self._mode, + "format": self._format, + "partitionBy": self._partition_by, + "options": self._options, + "bucketBy": self._bucket_by, + "sortBy": self._sort_by, + } + + # config = {} + # + # if cls._mode: + # config["mode"] = cls._mode + # if cls._format: + # config["format"] = cls._format + # if cls._partition_by: + # config["partitionBy"] = cls._partition_by + # if cls._options: + # config["options"] = cls._options + # if cls._bucket_by: + # config["bucketBy"] = cls._bucket_by + # if cls._sort_by: + # config["sortBy"] = cls._sort_by + # + # return config diff --git a/spark_expectations/examples/base_setup.py b/spark_expectations/examples/base_setup.py index db4f65b0..8b873bc6 100644 --- a/spark_expectations/examples/base_setup.py +++ b/spark_expectations/examples/base_setup.py @@ -1,37 +1,11 @@ import os -import subprocess +from pyspark.sql.session import SparkSession -# setting up env for local os.environ["SPARKEXPECTATIONS_ENV"] = "local" -from spark_expectations.core import get_spark_session +CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) -spark = get_spark_session() - - -def main() -> None: - os.environ["DQ_SPARK_EXPECTATIONS_CERBERUS_TOKEN"] = "" - current_dir = os.path.dirname(os.path.abspath(__file__)) - - print("Creating the necessary infrastructure for the tests to run locally!") - - # run kafka locally in docker - print("create or run if exist docker container") - os.system(f"sh {current_dir}/docker_scripts/docker_kafka_start_script.sh") - - # create database - os.system("rm -rf /tmp/hive/warehouse/dq_spark_local.db") - spark.sql("create database if not exists dq_spark_local") - spark.sql("use dq_spark_local") - - # create project_rules_table - spark.sql("drop table if exists dq_rules") - os.system("rm -rf /tmp/hive/warehouse/dq_spark_local.db/dq_rules") - - spark.sql( - """ - create table dq_rules ( - product_id STRING, +RULES_TABLE_SCHEMA = """ ( product_id STRING, table_name STRING, rule_type STRING, rule STRING, @@ -44,59 +18,15 @@ def main() -> None: enable_for_target_dq_validation BOOLEAN, is_active BOOLEAN, enable_error_drop_alert BOOLEAN, - error_drop_threshold INT - ) - USING delta - """ - ) - - spark.sql( - "ALTER TABLE dq_rules ADD CONSTRAINT rule_type_action CHECK (rule_type in ('row_dq', 'agg_dq', 'query_dq'));" - ) - - spark.sql( - "ALTER TABLE dq_rules ADD CONSTRAINT action CHECK ((rule_type = 'row_dq' and action_if_failed IN ('ignore', 'drop', 'fail')) or " - "(rule_type = 'agg_dq' and action_if_failed in ('ignore', 'fail')) or (rule_type = 'query_dq' and action_if_failed in ('ignore', 'fail')));" - ) - - # create project_dq_stats_table - # spark.sql("drop table if exists dq_stats") - # os.system("rm -rf /tmp/hive/warehouse/dq_spark_local.db/dq_stats") - # spark.sql( - # """ - # create table dq_stats ( - # product_id STRING, - # table_name STRING, - # input_count LONG, - # error_count LONG, - # output_count LONG, - # output_percentage FLOAT, - # success_percentage FLOAT, - # error_percentage FLOAT, - # source_agg_dq_results array>, - # final_agg_dq_results array>, - # source_query_dq_results array>, - # final_query_dq_results array>, - # row_dq_res_summary array>, - # dq_status map, - # dq_run_time map, - # dq_rules map>, - # meta_dq_run_id STRING, - # meta_dq_run_date DATE, - # meta_dq_run_datetime TIMESTAMP - # ) - # USING delta - # """ - # ) - - spark.sql( - """ - insert into table dq_rules values - ("your_product", "dq_spark_local.customer_order", "row_dq", "customer_id_is_not_null", "customer_id", "customer_id is not 
null","drop", "validity", "customer_id ishould not be null", true, true, true, false, 0) - ,("your_product", "dq_spark_local.customer_order", "row_dq", "sales_greater_than_zero", "sales", "sales > 0", "drop", "accuracy", "sales value should be greater than zero", true, true, true, false, 0) - ,("your_product", "dq_spark_local.customer_order", "row_dq", "discount_threshold", "discount", "discount*100 < 60","drop", "validity", "discount should be less than 40", true, true, true, false, 0) - ,("your_product", "dq_spark_local.customer_order", "row_dq", "ship_mode_in_set", "ship_mode", "lower(trim(ship_mode)) in('second class', 'standard class', 'standard class')", "drop", "validity", "ship_mode mode belongs in the sets", true, true, true, false, 0) - ,("your_product", "dq_spark_local.customer_order", "row_dq", "profit_threshold", "profit", "profit>0", "drop", "validity", "profit threshold should be greater tahn 0", true, true, true, true, 0) + error_drop_threshold INT ) +""" + +RULES_DATA = """ + ("your_product", "dq_spark_local.customer_order", "row_dq", "customer_id_is_not_null", "customer_id", "customer_id is not null","drop", "validity", "customer_id ishould not be null", true, true,false, false, 0) + ,("your_product", "dq_spark_local.customer_order", "row_dq", "sales_greater_than_zero", "sales", "sales > 2", "drop", "accuracy", "sales value should be greater than zero", true, true, true, false, 0) + ,("your_product", "dq_spark_local.customer_order", "row_dq", "discount_threshold", "discount", "discount*100 < 60","drop", "validity", "discount should be less than 40", true, true, false, false, 0) + ,("your_product", "dq_spark_local.customer_order", "row_dq", "ship_mode_in_set", "ship_mode", "lower(trim(ship_mode)) in('second class', 'standard class', 'standard class')", "drop", "validity", "ship_mode mode belongs in the sets", true, true, false, false, 0) + ,("your_product", "dq_spark_local.customer_order", "row_dq", "profit_threshold", "profit", "profit>0", "drop", "validity", "profit threshold should be greater tahn 0", true, true, false, true, 0) ,("your_product", "dq_spark_local.customer_order", "agg_dq", "sum_of_sales", "sales", "sum(sales)>10000", "ignore", "validity", "regex format validation for quantity", true, true, true, false, 0) ,("your_product", "dq_spark_local.customer_order", "agg_dq", "sum_of_quantity", "quantity", "sum(sales)>10000", "ignore", "validity", "regex format validation for quantity", true, true, true, false, 0) @@ -106,61 +36,112 @@ def main() -> None: ,("your_product", "dq_spark_local.customer_order", "query_dq", "product_missing_count_threshold", "*", "((select count(distinct product_id) from product) - (select count(distinct product_id) from order))>(select count(distinct product_id) from product)*0.2", "ignore", "validity", "row count threshold", true, true, true, false, 0) ,("your_product", "dq_spark_local.customer_order", "query_dq", "product_category", "*", "(select count(distinct category) from product) < 5", "ignore", "validity", "distinct product category", true, true, true, false, 0) ,("your_product", "dq_spark_local.customer_order", "query_dq", "row_count_in_order", "*", "(select count(*) from order)<10000", "ignore", "accuracy", "count of the row in order dataset", true, true, true, false, 0) - """ - ) + +""" - # , ("your_product", "dq_spark_local.customer_order", "row_dq", "referential_integrity_customer_id", "customer_id", - # "customer_id in(select distinct customer_id from customer)", true, true, "drop", true, "validity", - # "referential 
integrity for cuatomer_id") - # , ("your_product", "dq_spark_local.customer_order", "row_dq", "referential_integrity_product_id", "product_id", - # "select count(*) from (select distinct product_id as ref_product from product) where product_id=ref_product > 1", - # true, true, "drop", true, "validity", "referntial integrity for product_id") - # , ( - # "your_product", "dq_spark_local.customer_order", "row_dq", "regex_format_sales", "sales", "sales rlike '[1-9]+.[1-9]+'", - # true, true, "drop", true, "validity", "regex format validation for sales") - # , ("your_product", "dq_spark_local.customer_order", "row_dq", "regex_format_quantity", "quantity", - # "quantity rlike '[1-9]+.[1-9]+'", true, true, "drop", true, "validity", "regex format validation for quantity") - # , ("your_product", "dq_spark_local.customer_order", "row_dq", "date_format_order_date", "order_date", - # "order_date rlike '([1-3][1-9]|[0-1])/([1-2]|[1-9])/20[0-2][0-9]''", true, true, "drop", true, "validity", - # "regex format validation for quantity") - # , ("your_product", "dq_spark_local.customer_order", "row_dq", "regex_format_order_id", "order_id", - # "order_id rlike '(US|CA)-20[0-2][0-9]-*''", true, true, "drop", true, "validity", - # "regex format validation for quantity") - - # , ("your_product", "dq_spark_local.employee_new", "query_dq", "", "*", - # "(select count(*) from dq_spark_local_employee_new)!=(select count(*) from dq_spark_local_employee_new)", true, - # false, "ignore", false, "validity", "canary check to comapre the two table count") - # , ("your_product", "dq_spark_local.employee_new", "query_dq", "department_salary_threshold", "department", - # "(select count(*) from (select department from dq_spark_local_employee_new group by department having sum(bonus)>1000))<1", - # true, false, "ignore", true, "validity", "each sub-department threshold") - # , ( - # "your_product", "dq_spark_local.employee_new", "query_dq", "count_of_exit_date_nulls_threshold", "exit_date", "", true, - # true, "ignore", false, "validity", "exit_date null threshold") - - # , ("your_product", "dq_spark_local.customer_order", "row_dq", "complete_duplicate", "*", - # "count(*) over(partition by customer_id,product_id,order_id,order_date,ship_date,ship_mode,sales,quantity,discount,profit order by 1)", - # true, true, "drop", true, "validity", "complete duplicate record") - # , ("your_product", "dq_spark_local.customer_order", "row_dq", "primary_key_check", "*", - # "count(*) over(partition by customer_id, order_id order by 1)", true, true, "drop", true, "validity", - # "primary key check") - - # ,("your_product", "dq_spark_local.customer_order", "row_dq", "order_date_format_check", "order_date", "to_date(order_date, 'dd/MM/yyyy')", true, true,"drop" ,true, "validity", "Age of the employee should be less than 65") - spark.sql("select * from dq_rules").show(truncate=False) +def set_up_kafka() -> None: + print("create or run if exist docker container") + os.system(f"sh {CURRENT_DIR}/docker_scripts/docker_kafka_start_script.sh") + - # DROP the data tables and error tables - spark.sql("drop table if exists dq_spark_local.customer_order") - os.system( - "rm -rf /tmp/hive/warehouse/dq_spark_local.db/dq_spark_local.customer_order" +def add_kafka_jars(builder: SparkSession.builder) -> SparkSession.builder: + return builder.config( # below jars are used only in the local env, not coupled with databricks or EMR + "spark.jars", + f"{CURRENT_DIR}/../../jars/spark-sql-kafka-0-10_2.12-3.0.0.jar," + 
f"{CURRENT_DIR}/../../jars/kafka-clients-3.0.0.jar," + f"{CURRENT_DIR}/../../jars/commons-pool2-2.8.0.jar," + f"{CURRENT_DIR}/../../jars/spark-token-provider-kafka-0-10_2.12-3.0.0.jar", ) - spark.sql("drop table if exists dq_spark_local.customer_order_error") - os.system( - "rm -rf /tmp/hive/warehouse/dq_spark_local.db/dq_spark_local.customer_order_error" + +def set_up_iceberg() -> SparkSession: + set_up_kafka() + spark = add_kafka_jars( + SparkSession.builder.config( + "spark.jars.packages", + "org.apache.iceberg:iceberg-spark-runtime-3.4_2.12:1.3.1", + ) + .config( + "spark.sql.extensions", + "org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions", + ) + .config( + "spark.sql.catalog.spark_catalog", + "org.apache.iceberg.spark.SparkSessionCatalog", + ) + .config("spark.sql.catalog.spark_catalog.type", "hadoop") + .config("spark.sql.catalog.spark_catalog.warehouse", "/tmp/hive/warehouse") + .config("spark.sql.catalog.local", "org.apache.iceberg.spark.SparkCatalog") + .config("spark.sql.catalog.local.type", "hadoop") + .config("spark.sql.catalog.local.warehouse", "/tmp/hive/warehouse") + ).getOrCreate() + + os.system("rm -rf /tmp/hive/warehouse/dq_spark_local") + + spark.sql("create database if not exists spark_catalog.dq_spark_local") + spark.sql(" use spark_catalog.dq_spark_local") + + spark.sql("drop table if exists dq_spark_local.dq_stats") + + spark.sql("drop table if exists dq_spark_local.dq_rules") + + spark.sql( + f" CREATE TABLE dq_spark_local.dq_rules {RULES_TABLE_SCHEMA} USING ICEBERG" ) + spark.sql(f" INSERT INTO dq_spark_local.dq_rules values {RULES_DATA} ") + + spark.sql("select * from dq_spark_local.dq_rules").show(truncate=False) + return spark + + +def set_up_bigquery(materialization_dataset: str) -> SparkSession: + set_up_kafka() + spark = add_kafka_jars( + SparkSession.builder.config( + "spark.jars.packages", + "com.google.cloud.spark:spark-bigquery-with-dependencies_2.12:0.30.0", + ) + ).getOrCreate() + spark._jsc.hadoopConfiguration().set( + "fs.gs.impl", "com.google.cloud.hadoop.fs.gcs.GoogleHadoopFileSystem" + ) + spark.conf.set("viewsEnabled", "true") + spark.conf.set("materializationDataset", materialization_dataset) + + # Add dependencies like gcs-connector-hadoop3-2.2.6-SNAPSHOT-shaded.jar, spark-avro_2.12-3.4.1.jar if you wanted to use indirect method for reading/writing + return spark + + +def set_up_delta() -> SparkSession: + set_up_kafka() + + builder = add_kafka_jars( + SparkSession.builder.config( + "spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension" + ) + .config("spark.jars.packages", "io.delta:delta-core_2.12:2.4.0") + .config( + "spark.sql.catalog.spark_catalog", + "org.apache.spark.sql.delta.catalog.DeltaCatalog", + ) + .config("spark.sql.warehouse.dir", "/tmp/hive/warehouse") + .config("spark.driver.extraJavaOptions", "-Dderby.system.home=/tmp/derby") + .config("spark.jars.ivy", "/tmp/ivy2") + ) + spark = builder.getOrCreate() + + os.system("rm -rf /tmp/hive/warehouse/dq_spark_local.db") + + spark.sql("create database if not exists dq_spark_local") + spark.sql("use dq_spark_local") - print("Local infrastructure setup is done") + spark.sql("drop table if exists dq_stats") + spark.sql("drop table if exists dq_rules") -if __name__ == "__main__": - main() + spark.sql(f" CREATE TABLE dq_rules {RULES_TABLE_SCHEMA} USING DELTA") + spark.sql(f" INSERT INTO dq_rules values {RULES_DATA}") + + spark.sql("select * from dq_rules").show(truncate=False) + return spark diff --git 
a/spark_expectations/examples/docker_scripts/docker_kafka_start_script.sh b/spark_expectations/examples/docker_scripts/docker_kafka_start_script.sh index 6d0e3b2e..906f7c71 100644 --- a/spark_expectations/examples/docker_scripts/docker_kafka_start_script.sh +++ b/spark_expectations/examples/docker_scripts/docker_kafka_start_script.sh @@ -3,7 +3,7 @@ file_dir=$(dirname "$0") docker_container_name="spark_expectations_kafka_docker" -docker_image_name="spark_expectations_nsp_topic" +docker_image_name="spark_expectations_kafka_topic" if [[ $(docker ps -a | grep "$docker_container_name") ]]; then if [[ $(docker ps | grep "$docker_container_name") ]]; then diff --git a/spark_expectations/examples/docker_scripts/docker_kafka_stop_script.sh b/spark_expectations/examples/docker_scripts/docker_kafka_stop_script.sh index 62dcf84e..698935d0 100644 --- a/spark_expectations/examples/docker_scripts/docker_kafka_stop_script.sh +++ b/spark_expectations/examples/docker_scripts/docker_kafka_stop_script.sh @@ -2,7 +2,7 @@ #shell scripts to remove docker container docker_container_name="spark_expectations_kafka_docker" -docker_image_name="spark_expectations_nsp_topic" +docker_image_name="spark_expectations_kafka_topic" if [[ $(docker ps -a | grep "$docker_container_name") ]]; then docker rm -f "$docker_container_name" diff --git a/spark_expectations/examples/sample_dq_bigquery.py b/spark_expectations/examples/sample_dq_bigquery.py new file mode 100644 index 00000000..e55417a5 --- /dev/null +++ b/spark_expectations/examples/sample_dq_bigquery.py @@ -0,0 +1,120 @@ +# mypy: ignore-errors +import os +from pyspark.sql import DataFrame +from spark_expectations import _log +from spark_expectations.core.expectations import ( + SparkExpectations, + WrappedDataFrameWriter, +) +from spark_expectations.config.user_config import Constants as user_config +from spark_expectations.examples.base_setup import set_up_bigquery + +os.environ[ + "GOOGLE_APPLICATION_CREDENTIALS" +] = "path_to_your_json_credential_file" # needed for Spark to write to BigQuery +writer = ( + WrappedDataFrameWriter() + .mode("overwrite") + .format("bigquery") + .option("createDisposition", "CREATE_IF_NEEDED") + .option("writeMethod", "direct") +) + +# to use the indirect write method instead, use the settings below with a matching spark session +# writer = WrappedDataFrameWriter().mode("overwrite").format("bigquery").\ +# option("createDisposition", "CREATE_IF_NEEDED")\ +# .option("writeMethod", "indirect")\ +# .option("intermediateFormat", "AVRO")\ +# .option("temporaryGcsBucket", "") + + +# pass materialization dataset +spark = set_up_bigquery("") + +se: SparkExpectations = SparkExpectations( + product_id="your_product", + rules_df=spark.read.format("bigquery").load( + ".."
+ ), + stats_table="..", + stats_table_writer=writer, + target_and_error_table_writer=writer, + debugger=False, +) + +user_conf = { + user_config.se_notifications_enable_email: False, + user_config.se_notifications_email_smtp_host: "mailhost.com", + user_config.se_notifications_email_smtp_port: 25, + user_config.se_notifications_email_from: "", + user_config.se_notifications_email_to_other_mail_id: "", + user_config.se_notifications_email_subject: "spark expectations - data quality - notifications", + user_config.se_notifications_enable_slack: False, + user_config.se_notifications_slack_webhook_url: "", + user_config.se_notifications_on_start: True, + user_config.se_notifications_on_completion: True, + user_config.se_notifications_on_fail: True, + user_config.se_notifications_on_error_drop_exceeds_threshold_breach: True, + user_config.se_notifications_on_error_drop_threshold: 15, +} + + +@se.with_expectations( + target_table="..", + write_to_table=True, + user_conf=user_conf, + target_table_view="..", +) +def build_new() -> DataFrame: + _df_order: DataFrame = ( + spark.read.option("header", "true") + .option("inferSchema", "true") + .csv(os.path.join(os.path.dirname(__file__), "resources/order.csv")) + ) + _df_order.createOrReplaceTempView("order") + + _df_product: DataFrame = ( + spark.read.option("header", "true") + .option("inferSchema", "true") + .csv(os.path.join(os.path.dirname(__file__), "resources/product.csv")) + ) + _df_product.createOrReplaceTempView("product") + + _df_customer: DataFrame = ( + spark.read.option("header", "true") + .option("inferSchema", "true") + .csv(os.path.join(os.path.dirname(__file__), "resources/customer.csv")) + ) + + _df_customer.createOrReplaceTempView("customer") + + return _df_order + + +if __name__ == "__main__": + build_new() + + spark.sql("select * from dq_spark_local.dq_stats").show(truncate=False) + spark.sql("select * from dq_spark_local.dq_stats").printSchema() + spark.sql("select * from dq_spark_local.customer_order").show(truncate=False) + spark.sql("select count(*) from dq_spark_local.customer_order_error").show( + truncate=False + ) + + _log.info("stats data in the kafka topic") + # display posted statistics from the kafka topic + spark.read.format("kafka").option( + "kafka.bootstrap.servers", "localhost:9092" + ).option("subscribe", "dq-sparkexpectations-stats").option( + "startingOffsets", "earliest" + ).option( + "endingOffsets", "latest" + ).load().selectExpr( + "cast(value as string) as stats_records" + ).show( + truncate=False + ) + + # remove docker container + current_dir = os.path.dirname(os.path.abspath(__file__)) + os.system(f"sh {current_dir}/docker_scripts/docker_kafka_stop_script.sh") diff --git a/spark_expectations/examples/sample_dq.py b/spark_expectations/examples/sample_dq_delta.py similarity index 71% rename from spark_expectations/examples/sample_dq.py rename to spark_expectations/examples/sample_dq_delta.py index fd5ff2fd..636a6d4f 100644 --- a/spark_expectations/examples/sample_dq.py +++ b/spark_expectations/examples/sample_dq_delta.py @@ -1,17 +1,32 @@ +# mypy: ignore-errors import os + from pyspark.sql import DataFrame from spark_expectations import _log -from spark_expectations.examples.base_setup import main -from spark_expectations.core import get_spark_session -from spark_expectations.core.expectations import SparkExpectations +from spark_expectations.examples.base_setup import set_up_delta +from spark_expectations.core.expectations import ( + SparkExpectations, + WrappedDataFrameWriter, +) from 
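The block that reads the posted statistics back from Kafka recurs verbatim in the delta and iceberg samples that follow. A small helper (hypothetical, not part of the package) captures the pattern once, using exactly the options from the sample scripts:

```python
from pyspark.sql import DataFrame, SparkSession


def read_stats_topic(
    spark: SparkSession, topic: str = "dq-sparkexpectations-stats"
) -> DataFrame:
    # batch-read the whole topic and expose the payload as a string column
    return (
        spark.read.format("kafka")
        .option("kafka.bootstrap.servers", "localhost:9092")
        .option("subscribe", topic)
        .option("startingOffsets", "earliest")
        .option("endingOffsets", "latest")
        .load()
        .selectExpr("cast(value as string) as stats_records")
    )
```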
spark_expectations.config.user_config import Constants as user_config -main() -se: SparkExpectations = SparkExpectations(product_id="your_product", debugger=False) -spark = get_spark_session() +writer = WrappedDataFrameWriter().mode("append").format("delta") + +spark = set_up_delta() + +se: SparkExpectations = SparkExpectations( + product_id="your_product", + rules_df=spark.table("dq_spark_local.dq_rules"), + stats_table="dq_spark_local.dq_stats", + stats_table_writer=writer, + target_and_error_table_writer=writer, + debugger=False, + # stats_streaming_options={user_config.se_enable_streaming: False}, +) + -global_spark_Conf = { +user_conf = { user_config.se_notifications_enable_email: False, user_config.se_notifications_email_smtp_host: "mailhost.com", user_config.se_notifications_email_smtp_port: 25, @@ -29,25 +44,10 @@ @se.with_expectations( - se.reader.get_rules_from_table( - product_rules_table="dq_spark_local.dq_rules", - target_table_name="dq_spark_local.customer_order", - dq_stats_table_name="dq_spark_local.dq_stats", - ), + target_table="dq_spark_local.customer_order", write_to_table=True, - row_dq=True, - agg_dq={ - user_config.se_agg_dq: True, - user_config.se_source_agg_dq: True, - user_config.se_final_agg_dq: True, - }, - query_dq={ - user_config.se_query_dq: True, - user_config.se_source_query_dq: True, - user_config.se_final_query_dq: True, - user_config.se_target_table_view: "order", - }, - spark_conf=global_spark_Conf, + user_conf=user_conf, + target_table_view="order", ) def build_new() -> DataFrame: _df_order: DataFrame = ( @@ -78,10 +78,15 @@ def build_new() -> DataFrame: if __name__ == "__main__": build_new() + spark.sql("use dq_spark_local") spark.sql("select * from dq_spark_local.dq_stats").show(truncate=False) spark.sql("select * from dq_spark_local.dq_stats").printSchema() + spark.sql("select * from dq_spark_local.customer_order").show(truncate=False) + spark.sql("select count(*) from dq_spark_local.customer_order_error").show( + truncate=False + ) - _log.info("stats data in the nsp topic") + _log.info("stats data in the kafka topic") # display posted statistics from the kafka topic spark.read.format("kafka").option( "kafka.bootstrap.servers", "localhost:9092" diff --git a/spark_expectations/examples/sample_dq_iceberg.py b/spark_expectations/examples/sample_dq_iceberg.py new file mode 100644 index 00000000..a9128c37 --- /dev/null +++ b/spark_expectations/examples/sample_dq_iceberg.py @@ -0,0 +1,103 @@ +# mypy: ignore-errors +import os + +from pyspark.sql import DataFrame +from spark_expectations import _log +from spark_expectations.examples.base_setup import set_up_iceberg +from spark_expectations.core.expectations import ( + SparkExpectations, + WrappedDataFrameWriter, +) +from spark_expectations.config.user_config import Constants as user_config + +writer = WrappedDataFrameWriter().mode("append").format("iceberg") + +spark = set_up_iceberg() + +print(spark.sparkContext.getConf().getAll()) +se: SparkExpectations = SparkExpectations( + product_id="your_product", + rules_df=spark.sql("select * from dq_spark_local.dq_rules"), + stats_table="dq_spark_local.dq_stats", + stats_table_writer=writer, + target_and_error_table_writer=writer, + debugger=False, + stats_streaming_options={user_config.se_enable_streaming: False}, +) + +user_conf = { + user_config.se_notifications_enable_email: False, + user_config.se_notifications_email_smtp_host: "mailhost.com", + user_config.se_notifications_email_smtp_port: 25, + user_config.se_notifications_email_from: "", + 
user_config.se_notifications_email_to_other_mail_id: "", + user_config.se_notifications_email_subject: "spark expectations - data quality - notifications", + user_config.se_notifications_enable_slack: False, + user_config.se_notifications_slack_webhook_url: "", + user_config.se_notifications_on_start: True, + user_config.se_notifications_on_completion: True, + user_config.se_notifications_on_fail: True, + user_config.se_notifications_on_error_drop_exceeds_threshold_breach: True, + user_config.se_notifications_on_error_drop_threshold: 15, +} + + +@se.with_expectations( + target_table="dq_spark_local.customer_order", + write_to_table=True, + user_conf=user_conf, + target_table_view="order", +) +def build_new() -> DataFrame: + _df_order: DataFrame = ( + spark.read.option("header", "true") + .option("inferSchema", "true") + .csv(os.path.join(os.path.dirname(__file__), "resources/order.csv")) + ) + _df_order.createOrReplaceTempView("order") + + _df_product: DataFrame = ( + spark.read.option("header", "true") + .option("inferSchema", "true") + .csv(os.path.join(os.path.dirname(__file__), "resources/product.csv")) + ) + _df_product.createOrReplaceTempView("product") + + _df_customer: DataFrame = ( + spark.read.option("header", "true") + .option("inferSchema", "true") + .csv(os.path.join(os.path.dirname(__file__), "resources/customer.csv")) + ) + + _df_customer.createOrReplaceTempView("customer") + + return _df_order + + +if __name__ == "__main__": + build_new() + + spark.sql("select * from dq_spark_local.dq_stats").show(truncate=False) + spark.sql("select * from dq_spark_local.dq_stats").printSchema() + spark.sql("select * from dq_spark_local.customer_order").show(truncate=False) + spark.sql("select count(*) from dq_spark_local.customer_order_error").show( + truncate=False + ) + + _log.info("stats data in the kafka topic") + # display posted statistics from the kafka topic + spark.read.format("kafka").option( + "kafka.bootstrap.servers", "localhost:9092" + ).option("subscribe", "dq-sparkexpectations-stats").option( + "startingOffsets", "earliest" + ).option( + "endingOffsets", "latest" + ).load().selectExpr( + "cast(value as string) as stats_records" + ).show( + truncate=False + ) + + # remove docker container + current_dir = os.path.dirname(os.path.abspath(__file__)) + os.system(f"sh {current_dir}/docker_scripts/docker_kafka_stop_script.sh") diff --git a/spark_expectations/notifications/push/spark_expectations_notify.py b/spark_expectations/notifications/push/spark_expectations_notify.py index 4659b0cc..d17ac0c6 100644 --- a/spark_expectations/notifications/push/spark_expectations_notify.py +++ b/spark_expectations/notifications/push/spark_expectations_notify.py @@ -12,7 +12,6 @@ class SparkExpectationsNotify: This class implements Notification """ - product_id: str _context: SparkExpectationsContext def __post_init__(self) -> None: @@ -89,7 +88,7 @@ def notify_on_exceeds_of_error_threshold(self) -> None: _notification_message = ( f"Spark expectations - dropped error percentage has been exceeded above the threshold " f"value({self._context.get_error_drop_threshold}%) for `row_data` quality validation \n\n" - f"product_id: {self.product_id}\n" + f"product_id: {self._context.product_id}\n" f"table_name: {self._context.get_table_name}\n" f"run_id: {self._context.get_run_id}\n" f"run_date: {self._context.get_run_date}\n" @@ -119,7 +118,7 @@ def notify_on_completion(self) -> None: _notification_message = ( "Spark expectations job has been completed \n\n" - f"product_id: {self.product_id}\n" + 
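The notification block is identical across the three sample scripts; only the channels toggled on differ. A trimmed variant enabling just Slack, as a sketch: it assumes keys omitted from user_conf keep their defaults, and the webhook URL is a placeholder:

```python
from spark_expectations.config.user_config import Constants as user_config

user_conf = {
    user_config.se_notifications_enable_slack: True,
    user_config.se_notifications_slack_webhook_url: "https://hooks.slack.com/services/...",  # placeholder
    user_config.se_notifications_on_fail: True,
    user_config.se_notifications_on_error_drop_exceeds_threshold_breach: True,
    user_config.se_notifications_on_error_drop_threshold: 15,
}
```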
f"product_id: {self._context.product_id}\n" f"table_name: {self._context.get_table_name}\n" f"run_id: {self._context.get_run_id}\n" f"run_date: {self._context.get_run_date}\n" @@ -151,7 +150,7 @@ def notify_on_failure(self, _error: str) -> None: _notification_message = ( "Spark expectations job has been failed \n\n" - f"product_id: {self.product_id}\n" + f"product_id: {self._context.product_id}\n" f"table_name: {self._context.get_table_name}\n" f"run_id: {self._context.get_run_id}\n" f"run_date: {self._context.get_run_date}\n" @@ -193,7 +192,7 @@ def construct_message_for_each_rules( _notification_message = ( f"{rule_name} has been exceeded above the threshold " f"value({set_error_drop_threshold}%) for `row_data` quality validation\n" - f"product_id: {self.product_id}\n" + f"product_id: {self._context.product_id}\n" f"table_name: {self._context.get_table_name}\n" f"run_id: {self._context.get_run_id}\n" f"run_date: {self._context.get_run_date}\n" @@ -213,9 +212,7 @@ def notify_on_exceeds_of_error_threshold_each_rules( """ This function sends notification when specific rule error drop percentage exceeds above threshold Args: - rule_name: name of the dq rule - failed_row_count: number of failed of dq rule - error_drop_percentage: error drop percentage + message: message to be sent in notification Returns: None """ @@ -255,7 +252,11 @@ def notify_rules_exceeds_threshold(self, rules: dict) -> None: rule_name = rule["rule"] rule_action = rule["action_if_failed"] - failed_row_count = int(rules_failed_row_count[rule_name]) + failed_row_count = int( + rules_failed_row_count[rule_name] + if rule_name in rules_failed_row_count + else 0 + ) if failed_row_count is not None and failed_row_count > 0: set_error_drop_threshold = int(rule["error_drop_threshold"]) diff --git a/spark_expectations/secrets/__init__.py b/spark_expectations/secrets/__init__.py index 9dc71b4c..849713dd 100644 --- a/spark_expectations/secrets/__init__.py +++ b/spark_expectations/secrets/__init__.py @@ -4,9 +4,9 @@ from dataclasses import dataclass from typing import Optional, Dict import pluggy -from spark_expectations.config.user_config import Constants as UserConfig -from spark_expectations.core import get_spark_session +from pyspark.sql.session import SparkSession from spark_expectations import _log +from spark_expectations.config.user_config import Constants as UserConfig SPARK_EXPECTATIONS_SECRETS_BACKEND = "spark_expectations_secrets_backend" @@ -92,8 +92,7 @@ def get_secret_value( try: from pyspark.dbutils import DBUtils - spark = get_spark_session() # pragma: no cover - + spark = SparkSession.getActiveSession() # pragma: no cover dbutils = DBUtils(spark) # pragma: no cover except ImportError: raise ImportError( diff --git a/spark_expectations/sinks/__init__.py b/spark_expectations/sinks/__init__.py index 74a3585e..d50b8c60 100644 --- a/spark_expectations/sinks/__init__.py +++ b/spark_expectations/sinks/__init__.py @@ -6,9 +6,6 @@ SPARK_EXPECTATIONS_WRITER_PLUGIN, ) -from spark_expectations.sinks.plugins.delta_writer import ( - SparkExpectationsDeltaWritePluginImpl, -) from spark_expectations.sinks.plugins.kafka_writer import ( SparkExpectationsKafkaWritePluginImpl, ) @@ -24,9 +21,6 @@ def get_sink_hook() -> pluggy.PluginManager: """ pm = pluggy.PluginManager(SPARK_EXPECTATIONS_WRITER_PLUGIN) pm.add_hookspecs(SparkExpectationsSinkWriter) - pm.register( - SparkExpectationsDeltaWritePluginImpl(), "spark_expectations_delta_write" - ) pm.register( SparkExpectationsKafkaWritePluginImpl(), "spark_expectations_kafka_write" ) diff 
--git a/spark_expectations/sinks/plugins/base_writer.py b/spark_expectations/sinks/plugins/base_writer.py index 56fd0c22..86758d9d 100644 --- a/spark_expectations/sinks/plugins/base_writer.py +++ b/spark_expectations/sinks/plugins/base_writer.py @@ -14,7 +14,7 @@ def writer( self, _write_args: Dict[Union[str], Union[str, bool, Dict[str, str], DataFrame]] ) -> None: """ - function consist signature to write data into delta/nsp, which will be implemented in the child class + function consists of the signature to write data into sinks such as kafka, which will be implemented in the child class Args: _write_args: diff --git a/spark_expectations/sinks/plugins/delta_writer.py b/spark_expectations/sinks/plugins/delta_writer.py deleted file mode 100644 index 85b97051..00000000 --- a/spark_expectations/sinks/plugins/delta_writer.py +++ /dev/null @@ -1,42 +0,0 @@ -from typing import Dict, Union -from pyspark.sql import DataFrame -from spark_expectations import _log -from spark_expectations.sinks.plugins.base_writer import ( - SparkExpectationsSinkWriter, - spark_expectations_writer_impl, -) -from spark_expectations.core import get_spark_session -from spark_expectations.core.exceptions import SparkExpectationsMiscException - - -class SparkExpectationsDeltaWritePluginImpl(SparkExpectationsSinkWriter): - """ - function implements/supports data into the delta table - """ - - @spark_expectations_writer_impl - def writer( - self, _write_args: Dict[Union[str], Union[str, bool, Dict[str, str], DataFrame]] - ) -> None: - """ - Args: - _write_args: - - Returns: - """ - try: - _log.info("started writing data into delta stats table") - df: DataFrame = _write_args.get("stats_df") - df.write.saveAsTable( - name=f"{_write_args.get('table_name')}", - **{"mode": "append", "format": "delta", "mergeSchema": "true"}, - ) - get_spark_session().sql( - f"ALTER TABLE {_write_args.get('table_name')} " - f"SET TBLPROPERTIES ('product_id' = '{_write_args.get('product_id')}')" - ) - _log.info("ended writing data into delta stats table") - except Exception as e: - raise SparkExpectationsMiscException( - f"error occurred while saving data into delta stats table {e}" - ) diff --git a/spark_expectations/sinks/plugins/kafka_writer.py b/spark_expectations/sinks/plugins/kafka_writer.py index b0d75953..75a4c4b7 100644 --- a/spark_expectations/sinks/plugins/kafka_writer.py +++ b/spark_expectations/sinks/plugins/kafka_writer.py @@ -34,7 +34,7 @@ def writer( # } if _write_args.pop("enable_se_streaming"): - _log.info("started write stats data into nsp stats topic") + _log.info("started write stats data into kafka stats topic") df: DataFrame = _write_args.get("stats_df") @@ -42,9 +42,9 @@ "append" ).options(**_write_args.get("kafka_write_options")).save() - _log.info("ended writing stats data into nsp stats topic") + _log.info("ended writing stats data into kafka stats topic") except Exception as e: raise SparkExpectationsMiscException( - f"error occurred while saving data into NSP {e}" + f"error occurred while saving data into kafka {e}" ) diff --git a/spark_expectations/sinks/utils/collect_statistics.py b/spark_expectations/sinks/utils/collect_statistics.py index edfef3e8..36d70845 100644 --- a/spark_expectations/sinks/utils/collect_statistics.py +++ b/spark_expectations/sinks/utils/collect_statistics.py @@ -12,7 +12,6 @@ class SparkExpectationsCollectStatistics: This class implements logging statistics on success and failure """ - product_id: str _context: SparkExpectationsContext _writer: SparkExpectationsWriter diff --git
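For reference, the _write_args dict handed to these sink plugins carries four keys, taken from the _sink_hook.writer call in writer.py later in this patch; the values here are illustrative:

```python
_write_args = {
    "product_id": "your_product",
    "enable_se_streaming": True,  # popped by the kafka plugin before the write
    "kafka_write_options": {
        "kafka.bootstrap.servers": "localhost:9092",
        "topic": "dq-sparkexpectations-stats",
        "failOnDataLoss": "true",
    },
    "stats_df": None,  # replaced at runtime by the one-row statistics DataFrame
}
```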
a/spark_expectations/sinks/utils/writer.py b/spark_expectations/sinks/utils/writer.py index 5e88db70..5c0fc7e3 100644 --- a/spark_expectations/sinks/utils/writer.py +++ b/spark_expectations/sinks/utils/writer.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Dict, Optional, Tuple, List, Any +from typing import Dict, Optional, Tuple, List from datetime import datetime from pyspark.sql import DataFrame from pyspark.sql.functions import ( @@ -11,13 +11,13 @@ round as sql_round, create_map, explode, + to_json, ) from spark_expectations import _log from spark_expectations.core.exceptions import ( SparkExpectationsUserInputOrConfigInvalidException, SparkExpectationsMiscException, ) -from spark_expectations.core import get_spark_session from spark_expectations.secrets import SparkExpectationsSecretsBackend from spark_expectations.utils.udf import remove_empty_maps from spark_expectations.core.context import SparkExpectationsContext @@ -31,18 +31,13 @@ class SparkExpectationsWriter: This class implements/supports writing data into the sink system """ - product_id: str _context: SparkExpectationsContext def __post_init__(self) -> None: - self.spark = get_spark_session() + self.spark = self._context.spark def save_df_as_table( - self, - df: DataFrame, - table_name: str, - spark_conf: Dict[str, str], - options: Dict[str, str], + self, df: DataFrame, table_name: str, config: dict, stats_table: bool = False ) -> None: """ This function takes a dataframe and writes into a table @@ -50,8 +45,8 @@ def save_df_as_table( Args: df: Provide the dataframe which need to be written as a table table_name: Provide the table name to which the dataframe need to be written to - spark_conf: Provide the spark conf that need to be set on the SparkSession - options: Provide the options that need to be used while writing the table + config: Provide the config to write the dataframe into the table + stats_table: Provide if this is for writing stats table Returns: None: @@ -59,62 +54,47 @@ def save_df_as_table( """ try: print("run date ", self._context.get_run_date) - for key, value in spark_conf.items(): - self.spark.conf.set(key, value) - _df = df.withColumn( - self._context.get_run_id_name, lit(f"{self._context.get_run_id}") - ).withColumn( - self._context.get_run_date_name, - to_timestamp(lit(f"{self._context.get_run_date}")), - ) + if not stats_table: + df = df.withColumn( + self._context.get_run_id_name, lit(f"{self._context.get_run_id}") + ).withColumn( + self._context.get_run_date_name, + to_timestamp( + lit(f"{self._context.get_run_date}"), "yyyy-MM-dd HH:mm:ss" + ), + ) _log.info("_save_df_as_table started") - _df.write.saveAsTable(name=table_name, **options) - self.spark.sql( - f"ALTER TABLE {table_name} SET TBLPROPERTIES ('product_id' = '{self.product_id}')" - ) - _log.info("finished writing records to table: %s", table_name) - except Exception as e: - raise SparkExpectationsUserInputOrConfigInvalidException( - f"error occurred while writing data in to the table {e}" - ) - - def write_df_to_table( - self, - df: DataFrame, - table: str, - spark_conf: Optional[Dict[str, Any]] = None, - options: Optional[Dict[str, str]] = None, - ) -> None: - """ - This function takes in a dataframe which has dq results publish it - - Args: - df: Provide a dataframe to write the records to a table. 
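save_df_as_table now takes a single config dict in place of the separate spark_conf/options arguments. Its expected shape, inferred from the branches in the new body below (values illustrative, matching a Delta append):

```python
config = {
    "mode": "append",
    "format": "delta",
    "partitionBy": [],   # list of column names; [] skips partitioning
    "sortBy": [],        # only honoured together with bucketBy
    "bucketBy": {},      # e.g. {"numBuckets": 4, "colName": "order_id"}
    "options": {"mergeSchema": "true"},
}
```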
- table : Provide the full original table name into which the data need to be written to - spark_conf: Provide the spark conf, if you want to set/override the configuration - options: Provide the options, if you want to override the default. - default options available are - {"mode": "append", "format": "delta"} - - Returns: - None: + _df_writer = df.write + + if config["mode"] is not None: + _df_writer = _df_writer.mode(config["mode"]) + if config["format"] is not None: + _df_writer = _df_writer.format(config["format"]) + if config["partitionBy"] is not None and config["partitionBy"] != []: + _df_writer = _df_writer.partitionBy(config["partitionBy"]) + if config["sortBy"] is not None and config["sortBy"] != []: + _df_writer = _df_writer.sortBy(config["sortBy"]) + if config["bucketBy"] is not None and config["bucketBy"] != {}: + bucket_by_config = config["bucketBy"] + _df_writer = _df_writer.bucketBy( + bucket_by_config["numBuckets"], bucket_by_config["colName"] + ) + if config["options"] is not None and config["options"] != {}: + _df_writer = _df_writer.options(**config["options"]) + + if config["format"] == "bigquery": + _df_writer.option("table", table_name).save() + else: + _df_writer.saveAsTable(name=table_name) + self.spark.sql( + f"ALTER TABLE {table_name} SET TBLPROPERTIES ('product_id' = '{self._context.product_id}')" + ) + _log.info("finished writing records to table: %s,", table_name) - """ - try: - _spark_conf = ( - {**{"spark.sql.session.timeZone": "Etc/UTC"}, **spark_conf} - if spark_conf - else {"spark.sql.session.timeZone": "Etc/UTC"} - ) - _options = ( - {**{"mode": "append", "format": "delta"}, **options} - if options - else {"mode": "append", "format": "delta"} - ) - self.save_df_as_table(df, table, _spark_conf, _options) except Exception as e: - raise SparkExpectationsMiscException( - f"error occurred while saving the data into the table {e}" + raise SparkExpectationsUserInputOrConfigInvalidException( + f"error occurred while writing data in to the table - {table_name}: {e}" ) def write_error_stats(self) -> None: @@ -122,12 +102,7 @@ def write_error_stats(self) -> None: This functions takes the stats table and write it into error table Args: - table_name: Provide the full table name to which the dq stats will be written to - input_count: Provide the original input dataframe count - error_count: Provide the error record count - output_count: Provide the output dataframe count - source_agg_dq_result: source aggregated dq results - final_agg_dq_result: final aggregated dq results + config: Provide the config to write the dataframe into the table Returns: None: @@ -137,13 +112,6 @@ def write_error_stats(self) -> None: self.spark.conf.set("spark.sql.session.timeZone", "Etc/UTC") from datetime import date - # table_name: str, - # input_count: int, - # error_count: int = 0, - # output_count: int = 0, - # source_agg_dq_result: Optional[List[Dict[str, str]]] = None, - # final_agg_dq_result: Optional[List[Dict[str, str]]] = None, - table_name: str = self._context.get_table_name input_count: int = self._context.get_input_count error_count: int = self._context.get_error_count @@ -163,7 +131,7 @@ def write_error_stats(self) -> None: error_stats_data = [ ( - self.product_id, + self._context.product_id, table_name, input_count, error_count, @@ -301,59 +269,74 @@ def write_error_stats(self) -> None: "Writing metrics to the stats table: %s, started", self._context.get_dq_stats_table_name, ) + if self._context.get_stats_table_writer_config["format"] == "bigquery": + df = 
df.withColumn("dq_rules", to_json(df["dq_rules"])) - _se_stats_dict = self._context.get_se_streaming_stats_dict + self.save_df_as_table( + df, + self._context.get_dq_stats_table_name, + config=self._context.get_stats_table_writer_config, + stats_table=True, + ) + + _log.info( + "Writing metrics to the stats table: %s, ended", + {self._context.get_dq_stats_table_name}, + ) - secret_handler = SparkExpectationsSecretsBackend(secret_dict=_se_stats_dict) + # TODO check if streaming_stats is set to off, if it's enabled only then this should run - kafka_write_options: dict = ( - { - "kafka.bootstrap.servers": "localhost:9092", - "topic": self._context.get_se_streaming_stats_topic_name, - "failOnDataLoss": "true", - } - if self._context.get_env == "local" - else ( + _se_stats_dict = self._context.get_se_streaming_stats_dict + if _se_stats_dict["se.enable.streaming"]: + secret_handler = SparkExpectationsSecretsBackend( + secret_dict=_se_stats_dict + ) + kafka_write_options: dict = ( { - "kafka.bootstrap.servers": f"{secret_handler.get_secret(self._context.get_server_url_key)}", - "kafka.security.protocol": "SASL_SSL", - "kafka.sasl.mechanism": "OAUTHBEARER", - "kafka.sasl.jaas.config": "kafkashaded.org.apache.kafka.common.security.oauthbearer." - "OAuthBearerLoginModule required oauth.client.id=" - f"'{secret_handler.get_secret(self._context.get_client_id)}' " - + "oauth.client.secret=" - f"'{secret_handler.get_secret(self._context.get_token)}' " - "oauth.token.endpoint.uri=" - f"'{secret_handler.get_secret(self._context.get_token_endpoint_url)}'; ", - "kafka.sasl.login.callback.handler.class": "io.strimzi.kafka.oauth.client" - ".JaasClientOauthLoginCallbackHandler", - "topic": ( - self._context.get_se_streaming_stats_topic_name - if self._context.get_env == "local" - else secret_handler.get_secret(self._context.get_topic_name) - ), + "kafka.bootstrap.servers": "localhost:9092", + "topic": self._context.get_se_streaming_stats_topic_name, + "failOnDataLoss": "true", } - if bool(_se_stats_dict[user_config.se_enable_streaming]) - else {} + if self._context.get_env == "local" + else ( + { + "kafka.bootstrap.servers": f"{secret_handler.get_secret(self._context.get_server_url_key)}", + "kafka.security.protocol": "SASL_SSL", + "kafka.sasl.mechanism": "OAUTHBEARER", + "kafka.sasl.jaas.config": "kafkashaded.org.apache.kafka.common.security.oauthbearer." 
+ "OAuthBearerLoginModule required oauth.client.id=" + f"'{secret_handler.get_secret(self._context.get_client_id)}' " + + "oauth.client.secret=" + f"'{secret_handler.get_secret(self._context.get_token)}' " + "oauth.token.endpoint.uri=" + f"'{secret_handler.get_secret(self._context.get_token_endpoint_url)}'; ", + "kafka.sasl.login.callback.handler.class": "io.strimzi.kafka.oauth.client" + ".JaasClientOauthLoginCallbackHandler", + "topic": ( + self._context.get_se_streaming_stats_topic_name + if self._context.get_env == "local" + else secret_handler.get_secret( + self._context.get_topic_name + ) + ), + } + ) ) - ) - _sink_hook.writer( - _write_args={ - "product_id": self.product_id, - "enable_se_streaming": _se_stats_dict[ - user_config.se_enable_streaming - ], - "table_name": self._context.get_dq_stats_table_name, - "kafka_write_options": kafka_write_options, - "stats_df": df, - } - ) - - _log.info( - "Writing metrics to the stats table: %s, ended", - self._context.get_dq_stats_table_name, - ) + _sink_hook.writer( + _write_args={ + "product_id": self._context.product_id, + "enable_se_streaming": _se_stats_dict[ + user_config.se_enable_streaming + ], + "kafka_write_options": kafka_write_options, + "stats_df": df, + } + ) + else: + _log.info( + "Streaming stats to kafka is disabled, hence skipping writing to kafka" + ) except Exception as e: raise SparkExpectationsMiscException( @@ -361,25 +344,10 @@ def write_error_stats(self) -> None: ) def write_error_records_final( - self, - df: DataFrame, - error_table: str, - rule_type: str, - spark_conf: Optional[Dict[str, str]] = None, - options: Optional[Dict[str, str]] = None, + self, df: DataFrame, error_table: str, rule_type: str ) -> Tuple[int, DataFrame]: try: _log.info("_write_error_records_final started") - _spark_conf = ( - {**{"spark.sql.session.timeZone": "Etc/UTC"}, **spark_conf} - if spark_conf - else {"spark.sql.session.timeZone": "Etc/UTC"} - ) - _options = ( - {**{"mode": "append", "format": "delta"}, **options} - if options - else {"mode": "append", "format": "delta"} - ) failed_records = [ f"size({dq_column}) != 0" @@ -420,7 +388,11 @@ def write_error_records_final( error_df = df.filter(f"size(meta_{rule_type}_results) != 0") self._context.print_dataframe_with_debugger(error_df) - self.save_df_as_table(error_df, error_table, _spark_conf, _options) + self.save_df_as_table( + error_df, + error_table, + self._context.get_target_and_error_table_writer_config, + ) _error_count = error_df.count() if _error_count > 0: diff --git a/spark_expectations/utils/actions.py b/spark_expectations/utils/actions.py index 7ef2cb9b..81b84521 100644 --- a/spark_expectations/utils/actions.py +++ b/spark_expectations/utils/actions.py @@ -40,10 +40,11 @@ def get_rule_is_active( ) -> bool: """ Args: - rule: - _rule_tye_name: - _source_dq_enabled: - _target_dq_enabled: + _context: SparkExpectationsContext class object + rule: dict with rule properties + _rule_type_name: which determines the type of the rule + _source_dq_enabled: Mark it as True when dq running for source dataframe + _target_dq_enabled: Mark it as True when dq running for target dataframe Returns: @@ -93,9 +94,10 @@ def create_rules_map(_rule_map: Dict[str, str]) -> Any: ] ) + @staticmethod def create_agg_dq_results( - self, _context: SparkExpectationsContext, _df: DataFrame, _rule_type_name: str - ) -> List[Dict[str, str]]: + _context: SparkExpectationsContext, _df: DataFrame, _rule_type_name: str + ) -> Optional[List[Dict[str, str]]]: """ This function helps to collect the aggregation results 
in to the list Args: @@ -109,13 +111,15 @@ def create_agg_dq_results( """ try: - return ( - _df.first()[f"meta_{_rule_type_name}_results"] - if _df + first_row = _df.first() + if ( + first_row is not None and f"meta_{_rule_type_name}_results" in _df.columns - and len(_df.first()[f"meta_{_rule_type_name}_results"]) > 0 - else None - ) + ): + meta_results = first_row[f"meta_{_rule_type_name}_results"] + if meta_results is not None and len(meta_results) > 0: + return meta_results + return None except Exception as e: raise SparkExpectationsMiscException( f"error occurred while running create agg dq results {e}" @@ -138,6 +142,8 @@ def run_dq_rules( df: Input dataframe on which data quality rules need to be applied expectations: Provide the dict which has all the rules rule_type: identifier for the type of rule to be applied in processing + _source_dq_enabled: Mark it as True when dq running for source dataframe + _target_dq_enabled: Mark it as True when dq running for target dataframe Returns: DataFrame: Returns a dataframe with all the rules run the input dataframe @@ -225,7 +231,6 @@ def run_dq_rules( def action_on_rules( _context: SparkExpectationsContext, _df_dq: DataFrame, - _table_name: str, _input_count: int, _error_count: int = 0, _output_count: int = 0, @@ -241,7 +246,6 @@ def action_on_rules( Args: _context: Provide SparkExpectationsContext _df_dq: Input dataframe on which data quality rules need to be applied - table_name: error table name _input_count: input dataset count _error_count: error count in the dataset _output_count: output count in the dataset @@ -251,11 +255,6 @@ def action_on_rules( _final_agg_dq_flag: Mark it as True when dq running for agg level expectations on final dataframe _source_query_dq_flag: Mark it as True when dq running for query level expectations on source dataframe _final_query_dq_flag: Mark it as True when dq running for query level expectations on final dataframe - _action_on: perform action on different stages in dq - _source_agg_dq_result: source aggregated data quality result - _final_agg_dq_result: final aggregated data quality result - _source_query_dq_result: source query based data quality result - _final_query_dq_result: final query based data quality result Returns: DataFrame: Returns a dataframe after dropping the error from the dataset diff --git a/spark_expectations/utils/reader.py b/spark_expectations/utils/reader.py index 2f298c35..bb42d0fb 100644 --- a/spark_expectations/utils/reader.py +++ b/spark_expectations/utils/reader.py @@ -1,17 +1,11 @@ import os -from typing import Optional, Union, List, Dict +from typing import Optional, Union, Dict from dataclasses import dataclass -# from cerberus.client import CerberusClient from pyspark.sql import DataFrame -from pyspark.sql.functions import ( - col, -) from spark_expectations.core.context import SparkExpectationsContext from spark_expectations.config.user_config import Constants as user_config -from spark_expectations.core import get_spark_session from spark_expectations.core.exceptions import ( - SparkExpectationsUserInputOrConfigInvalidException, SparkExpectationsMiscException, ) @@ -22,11 +16,10 @@ class SparkExpectationsReader: This class implements/supports reading data from source system """ - product_id: str _context: SparkExpectationsContext def __post_init__(self) -> None: - self.spark = get_spark_session() + self.spark = self._context.spark def set_notification_param( self, notification: Optional[Dict[str, Union[int, str, bool]]] = None @@ -124,74 +117,31 @@ def 
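create_agg_dq_results is now a true @staticmethod and guards against empty DataFrames: DataFrame.first() returns None when there are no rows, so the row, the column presence, and the fetched value are all checked before use. The guard in isolation, as a sketch with the agg_dq column name spelled out:

```python
from typing import Optional

from pyspark.sql import DataFrame


def first_agg_dq_results(df: DataFrame) -> Optional[list]:
    # df.first() is None for an empty DataFrame, so guard the row first,
    # then the column, then the value itself
    first_row = df.first()
    if first_row is not None and "meta_agg_dq_results" in df.columns:
        meta_results = first_row["meta_agg_dq_results"]
        if meta_results is not None and len(meta_results) > 0:
            return meta_results
    return None
```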
set_notification_param( f"error occurred while reading notification configurations {e}" ) - def get_rules_dlt( + def get_rules_from_df( self, - product_rules_table: str, - table_name: str, - action: Union[list, str], + rules_df: DataFrame, + target_table: str, + is_dlt: bool = False, tag: Optional[str] = None, - ) -> dict: - """ - This function supports creating a dict of expectations that is acceptable by DLT - Args: - product_rules_table: Provide the full table name, which has your data quality rules - table_name: Provide the full table name for which the data quality rules are being run - action: Provide the action which you want to filter from rules table. Value should only from one of these - - "fail" or "drop" or "ignore" or provide the needed in a list ["fail", "drop", "ignore"] - tag: Provide the KPI for which you are running the data quality rule - - Returns: - dict: returns a dict with key as 'rule' and 'expectation' as value - """ - try: - _actions: List[str] = [].append(action) if isinstance(action, str) else action # type: ignore - _expectations: dict = {} - _rules_df: DataFrame = self.spark.sql( - f""" - select rule, tag, expectation from {product_rules_table} - where product_id='{self.product_id}' and table_name='{table_name}' and - action_if_failed in ('{"', '".join(_actions)}') - """ - ) - if tag: - for row in _rules_df.filter(col("tag") == tag).collect(): - _expectations[row["rule"]] = row["expectation"] - else: - for row in _rules_df.collect(): - _expectations[row["rule"]] = row["expectation"] - return _expectations - - except Exception as e: - raise SparkExpectationsUserInputOrConfigInvalidException( - f"error occurred while reading or getting rules from the rules table {e}" - ) - - def get_rules_from_table( - self, - product_rules_table: str, - dq_stats_table_name: str, - target_table_name: str, - actions_if_failed: Optional[List[str]] = None, - ) -> dict: + ) -> tuple[dict, dict]: """ This function fetches the data quality rules from the table and return it as a dictionary Args: - product_rules_table: Provide the full table name, which has your data quality rules - table_name: Provide the full table name for which the data quality rules are being run - dq_stats_table_name: Provide the table name, to which Data Quality Stats have to be written to - actions_if_failed: Provide the list of actions in ["fail", "drop", 'ignore'], which need to be applied on a - particular row if a rule failed + rules_df: DataFrame which has your data quality rules + target_table: Provide the full table name for which the data quality rules are being run + is_dlt: True if this for fetching the rules for dlt job + tag: If is_dlt is True, provide the KPI for which you are running the data quality rule Returns: - dict: The dict with table and rules as keys + tuple: returns a tuple of two dictionaries with key as 'rule_type' and 'rules_table_row' as value in + expectations. 
dict, and key as 'dq_stage_setting' and 'boolean' as value in rules_execution_settings + dict """ try: - self._context.set_dq_stats_table_name(dq_stats_table_name) - - self._context.set_final_table_name(target_table_name) - self._context.set_error_table_name(f"{target_table_name}_error") - self._context.set_table_name(target_table_name) + self._context.set_final_table_name(target_table) + self._context.set_error_table_name(f"{target_table}_error") + self._context.set_table_name(target_table) self._context.set_env(os.environ.get("SPARKEXPECTATIONS_ENV")) self._context.reset_num_agg_dq_rules() @@ -199,65 +149,90 @@ def get_rules_from_table( self._context.reset_num_row_dq_rules() self._context.reset_num_query_dq_rules() - _actions_if_failed: List[str] = actions_if_failed or [ - "fail", - "drop", - "ignore", - ] - - _rules_df: DataFrame = self.spark.sql( - f""" - select * from {product_rules_table} where product_id='{self.product_id}' - and table_name='{target_table_name}' - and action_if_failed in ('{"', '".join(_actions_if_failed)}') and is_active=true - """ + _rules_df = rules_df.filter( + (rules_df.product_id == self._context.product_id) + & (rules_df.table_name == target_table) + & rules_df.is_active ) + self._context.print_dataframe_with_debugger(_rules_df) _expectations: dict = {} - for row in _rules_df.collect(): - column_map = { - "product_id": row["product_id"], - "table_name": row["table_name"], - "rule_type": row["rule_type"], - "rule": row["rule"], - "column_name": row["column_name"], - "expectation": row["expectation"], - "action_if_failed": row["action_if_failed"], - "enable_for_source_dq_validation": row[ - "enable_for_source_dq_validation" - ], - "enable_for_target_dq_validation": row[ - "enable_for_target_dq_validation" - ], - "tag": row["tag"], - "description": row["description"], - "enable_error_drop_alert": row["enable_error_drop_alert"], - "error_drop_threshold": row["error_drop_threshold"], - } - - if f"{row['rule_type']}_rules" in _expectations: - _expectations[f"{row['rule_type']}_rules"].append(column_map) + _rules_execution_settings: dict = {} + if is_dlt: + if tag: + for row in _rules_df.filter(_rules_df.tag == tag).collect(): + _expectations[row["rule"]] = row["expectation"] else: - _expectations[f"{row['rule_type']}_rules"] = [column_map] + for row in _rules_df.collect(): + _expectations[row["rule"]] = row["expectation"] + else: + for row in _rules_df.collect(): + column_map = { + "product_id": row["product_id"], + "table_name": row["table_name"], + "rule_type": row["rule_type"], + "rule": row["rule"], + "column_name": row["column_name"], + "expectation": row["expectation"], + "action_if_failed": row["action_if_failed"], + "enable_for_source_dq_validation": row[ + "enable_for_source_dq_validation" + ], + "enable_for_target_dq_validation": row[ + "enable_for_target_dq_validation" + ], + "tag": row["tag"], + "description": row["description"], + "enable_error_drop_alert": row["enable_error_drop_alert"], + "error_drop_threshold": row["error_drop_threshold"], + } + + if f"{row['rule_type']}_rules" in _expectations: + _expectations[f"{row['rule_type']}_rules"].append(column_map) + else: + _expectations[f"{row['rule_type']}_rules"] = [column_map] + + # count the rules enabled for the current run + if row["rule_type"] == self._context.get_row_dq_rule_type_name: + self._context.set_num_row_dq_rules() + elif row["rule_type"] == self._context.get_agg_dq_rule_type_name: + self._context.set_num_agg_dq_rules( + row["enable_for_source_dq_validation"], + 
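Because the rules now arrive as a plain DataFrame, any source that carries the rules-table columns will do. A hedged usage sketch; the reader instance and the CSV source are illustrative, not part of the samples:

```python
# assumes an active `spark` session and a constructed SparkExpectationsReader
# named `reader`, bound to a SparkExpectationsContext
rules_df = spark.read.option("header", "true").csv("resources/rules.csv")

expectations, rules_execution_settings = reader.get_rules_from_df(
    rules_df,
    target_table="dq_spark_local.customer_order",
)
```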
row["enable_for_target_dq_validation"], + ) + elif row["rule_type"] == self._context.get_query_dq_rule_type_name: + self._context.set_num_query_dq_rules( + row["enable_for_source_dq_validation"], + row["enable_for_target_dq_validation"], + ) - # count the rules enabled for the current run - if row["rule_type"] == self._context.get_row_dq_rule_type_name: - self._context.set_num_row_dq_rules() - elif row["rule_type"] == self._context.get_agg_dq_rule_type_name: - self._context.set_num_agg_dq_rules( - row["enable_for_source_dq_validation"], - row["enable_for_target_dq_validation"], + _expectations["target_table_name"] = target_table + _rules_execution_settings = self._get_rules_execution_settings( + _rules_df ) - elif row["rule_type"] == self._context.get_query_dq_rule_type_name: - self._context.set_num_query_dq_rules( - row["enable_for_source_dq_validation"], - row["enable_for_target_dq_validation"], - ) - - _expectations["target_table_name"] = target_table_name - return _expectations + return _expectations, _rules_execution_settings except Exception as e: raise SparkExpectationsMiscException( f"error occurred while retrieving rules list from the table {e}" ) + + def _get_rules_execution_settings(self, rules_df: DataFrame) -> dict: + rules_df.createOrReplaceTempView("rules_view") + df = self.spark.sql( + """SELECT + MAX(CASE WHEN rule_type = 'row_dq' THEN True ELSE False END) AS row_dq, + MAX(CASE WHEN rule_type = 'agg_dq' AND enable_for_source_dq_validation = true + THEN True ELSE False END) AS source_agg_dq, + MAX(CASE WHEN rule_type = 'query_dq' AND enable_for_source_dq_validation = true + THEN True ELSE False END) AS source_query_dq, + MAX(CASE WHEN rule_type = 'agg_dq' AND enable_for_target_dq_validation = true + THEN True ELSE False END) AS target_agg_dq, + MAX(CASE WHEN rule_type = 'query_dq' AND enable_for_target_dq_validation = true + THEN True ELSE False END) AS target_query_dq + FROM rules_view""" + ) + # convert the df to python dictionary as it has only one row + rule_execution_settings = df.collect()[0].asDict() + self.spark.catalog.dropTempView("rules_view") + return rule_execution_settings diff --git a/spark_expectations/utils/regulate_flow.py b/spark_expectations/utils/regulate_flow.py index a7608652..c0e84304 100644 --- a/spark_expectations/utils/regulate_flow.py +++ b/spark_expectations/utils/regulate_flow.py @@ -21,8 +21,8 @@ class SparkExpectationsRegulateFlow: product_id: str + @staticmethod def execute_dq_process( - self, _context: SparkExpectationsContext, _actions: SparkExpectationsActions, _writer: SparkExpectationsWriter, @@ -30,20 +30,17 @@ def execute_dq_process( expectations: Dict[str, List[dict]], table_name: str, _input_count: int = 0, - spark_conf: Optional[Dict[str, Any]] = None, - options_error_table: Optional[Dict[str, str]] = None, ) -> Any: """ This functions takes required static variable and returns the function Args: + _context: SparkExpectationsContext class object _actions: SparkExpectationsActions class object _writer: SparkExpectationsWriter class object + _notification: SparkExpectationsNotify class object expectations: expectations dictionary which contains rules table_name: name of the table _input_count: number of records in the source dataframe - spark_conf: spark configurations(which is optional) - options_error_table: spark configurations to write data into the error table(which is optional) - Returns: Any: returns function @@ -70,7 +67,7 @@ def func_process( final_agg_dq_flag: default false, Mark True tp process agg level data quality on 
final dataframe source_query_dq_flag: default false, Mark True tp process query level data quality on source dataframe final_query_dq_flag: default false, Mark True tp process query level data quality on final dataframe - error_count: number of records error records default zero) + error_count: number of records error records (default zero) output_count: number of output records from expectations (default zero) Returns: @@ -79,7 +76,6 @@ def func_process( """ try: - _df_dq: Optional[DataFrame] = None _error_df: Optional[DataFrame] = None _error_count: int = error_count @@ -97,7 +93,7 @@ def func_process( "The data quality dataframe is getting created for expectations" ) - _df_dq = _actions.run_dq_rules( + _df_dq: DataFrame = _actions.run_dq_rules( _context, df, expectations, @@ -128,8 +124,6 @@ def func_process( _df_dq, f"{table_name}_error", _context.get_row_dq_rule_type_name, - spark_conf, - options_error_table, ) if _context.get_summarised_row_dq_res: _notification.notify_rules_exceeds_threshold(expectations) @@ -153,7 +147,6 @@ def func_process( df = _actions.action_on_rules( _context, _error_df if row_dq_flag else _df_dq, - table_name, _input_count, _error_count=_error_count, _output_count=output_count, diff --git a/tests/config/test_user_config.py b/tests/config/test_user_config.py index 0dd9ee1d..25994a85 100644 --- a/tests/config/test_user_config.py +++ b/tests/config/test_user_config.py @@ -15,20 +15,6 @@ def test_constants(): assert user_config.se_notifications_slack_webhook_url == "spark.expectations.notifications.slack.webhook_url" - assert user_config.se_agg_dq == "agg_dq" - - assert user_config.se_source_agg_dq == "source_agg_dq" - - assert user_config.se_final_agg_dq == "final_agg_dq" - - assert user_config.se_query_dq == "query_dq" - - assert user_config.se_source_query_dq == "source_query_dq" - - assert user_config.se_final_query_dq == "final_query_dq" - - assert user_config.se_target_table_view == "target_table_view" - assert user_config.se_notifications_on_start == "spark.expectations.notifications.on_start" assert user_config.se_notifications_on_completion == "spark.expectations.notifications.on.completion" diff --git a/tests/core/test_context.py b/tests/core/test_context.py index e53fa593..eb799d8d 100644 --- a/tests/core/test_context.py +++ b/tests/core/test_context.py @@ -6,13 +6,13 @@ from spark_expectations.config.user_config import Constants as user_config from spark_expectations.core.context import SparkExpectationsContext from spark_expectations.core.exceptions import SparkExpectationsMiscException - - +from datetime import datetime, date +spark = get_spark_session() @patch("spark_expectations.core.context.uuid1") def test_context_init(mock_uuid): # Test that the product_id is set correctly mock_uuid.return_value = "hghg-gjgu-jgj" - context = SparkExpectationsContext(product_id='test_product') + context = SparkExpectationsContext(product_id='test_product', spark=spark) assert context.product_id == 'test_product' # Test that the run_id is set correctly @@ -25,7 +25,7 @@ def test_context_init(mock_uuid): def test_context_properties(): # Test that the getter properties return the correct values - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context._run_id = 'test_run_id' context._run_date = 'test_run_date' context._dq_stats_table_name = 'test_dq_stats_table' @@ -53,7 +53,7 @@ def test_context_properties(): context._input_count = 100 context._error_count = 10 context._output_count 
= 90 - context._nsp_stats_topic_name = "spark_expectations_stats_topic" + context._kafka_stats_topic_name = "spark_expectations_stats_topic" context._source_agg_dq_result = [ {"action_if_failed": "ignore", "rule_type": "agg_dq", "rule_name": "sum_of_salary_threshold", "rule": "sum(salary)>100"}] @@ -70,10 +70,10 @@ def test_context_properties(): context._cerberus_url = "https://xyz" context._cerberus_cred_path = "spark-expectations/credentials" context._cerberus_token = "xxx" - # context._nsp_bootstrap_server_url = "https://boostarp/server" - # context._nsp_secret = "xxxx" - # context._nsp_token_endpoint_uri = "https://token_uri" - # context._nsp_client_id = "spark-expectations" + # context._kafka_bootstrap_server_url = "https://boostarp/server" + # context._kafka_secret = "xxxx" + # context._kafka_token_endpoint_uri = "https://token_uri" + # context._kafka_client_id = "spark-expectations" context._run_id_name = "run_id" context._run_date_name = "run_date" @@ -84,7 +84,7 @@ def test_context_properties(): context._debugger_mode = False - context._supported_df_query_dq = get_spark_session().createDataFrame( + context._supported_df_query_dq = spark.createDataFrame( [ { "spark_expectations_query_check": "supported_place_holder_dataset_to_run_query_check" @@ -118,7 +118,7 @@ def test_context_properties(): "action_if_failed": "fail", "failed_row_count": 4}] - context._nsp_row_dq_res_topic_name = "abc" + context._kafka_row_dq_res_topic_name = "abc" context._se_streaming_stats_dict = {"a": "b", "c": "d"} context._se_streaming_stats_topic_name = "test_topic" @@ -150,7 +150,7 @@ def test_context_properties(): assert context._input_count == 100 assert context._error_count == 10 assert context._output_count == 90 - assert context._nsp_stats_topic_name == "spark_expectations_stats_topic" + assert context._kafka_stats_topic_name == "spark_expectations_stats_topic" assert context._source_agg_dq_result == [ {"action_if_failed": "ignore", "rule_type": "agg_dq", "rule_name": "sum_of_salary_threshold", "rule": "sum(salary)>100"}] @@ -170,16 +170,16 @@ def test_context_properties(): assert context.get_cerberus_cred_path == "spark-expectations/credentials" assert context._cerberus_token == "xxx" assert context.get_cerberus_token == "xxx" - # assert context._nsp_bootstrap_server_url == "https://boostarp/server" - # assert context._nsp_secret == "xxxx" - # assert context._nsp_token_endpoint_uri == "https://token_uri" - # assert context._nsp_client_id == "spark-expectations" + # assert context._kafka_bootstrap_server_url == "https://boostarp/server" + # assert context._kafka_secret == "xxxx" + # assert context._kafka_token_endpoint_uri == "https://token_uri" + # assert context._kafka_client_id == "spark-expectations" assert context._run_id_name == "run_id" assert context._run_date_name == "run_date" assert context._run_date_time_name == "run_date_time" - assert context._supported_df_query_dq == get_spark_session().createDataFrame( + assert context._supported_df_query_dq == spark.createDataFrame( [ { "spark_expectations_query_check": "supported_place_holder_dataset_to_run_query_check" @@ -215,74 +215,74 @@ def test_context_properties(): assert context._source_query_dq_status == "Passed" assert context._final_query_dq_status == "Skipped" - assert context._nsp_row_dq_res_topic_name == "abc" + assert context._kafka_row_dq_res_topic_name == "abc" assert context._se_streaming_stats_dict == {"a": "b", "c": "d"} assert context._se_streaming_stats_topic_name == "test_topic" def test_set_dq_stats_table_name(): - context = 
SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context.set_dq_stats_table_name("dq_stats_table_name") assert context._dq_stats_table_name == "dq_stats_table_name" assert context.get_dq_stats_table_name == "dq_stats_table_name" def test_set_final_table_name(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context.set_final_table_name("final_table_name") assert context._final_table_name == "final_table_name" assert context.get_final_table_name == "final_table_name" def test_error_table_name(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context.set_error_table_name("error_table_name") assert context._error_table_name == "error_table_name" assert context.get_error_table_name == "error_table_name" def test_row_dq_rule_type_name(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context._row_dq_rule_type_name = "row_dq1" context.get_row_dq_rule_type_name == "row_dq1" def test_agg_dq_rule_type_name(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context._agg_dq_rule_type_name = "row_dq1" context.get_agg_dq_rule_type_name == "row_dq1" def test_set_source_agg_dq_status(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context.set_source_agg_dq_status("Passed") assert context._source_agg_dq_status == "Passed" assert context.get_source_agg_dq_status == "Passed" def test_set_row_dq_status(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context.set_row_dq_status("Failed") assert context._row_dq_status == "Failed" assert context.get_row_dq_status == "Failed" def test_set_final_agg_dq_status(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context.set_final_agg_dq_status("Skipped") assert context._final_agg_dq_status == "Skipped" assert context.get_final_agg_dq_status == "Skipped" def test_set_dq_run_status(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context.set_dq_run_status("Passed") assert context._dq_run_status == "Passed" assert context.get_dq_run_status == "Passed" def test_get_source_agg_dq_status_exception(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context._source_agg_dq_status = None with pytest.raises(SparkExpectationsMiscException, match="The spark expectations context is not set completely, please assign " @@ -291,7 +291,7 @@ def test_get_source_agg_dq_status_exception(): def test_get_row_dq_status_exception(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context._row_dq_status = None with pytest.raises(SparkExpectationsMiscException, match="The spark expectations context is not set completely, please assign " @@ -300,7 +300,7 @@ def test_get_row_dq_status_exception(): def test_get_final_agg_dq_status_exception(): - context = 
SparkExpectationsContext(product_id="product1")
+    context = SparkExpectationsContext(product_id="product1", spark=spark)
     context._final_agg_dq_status = None
     with pytest.raises(SparkExpectationsMiscException,
                        match="The spark expectations context is not set completely, please assign "
@@ -309,7 +309,7 @@ def test_get_final_agg_dq_status_exception():


 def test_get_dq_run_status_exception():
-    context = SparkExpectationsContext(product_id="product1")
+    context = SparkExpectationsContext(product_id="product1", spark=spark)
     context._dq_run_status = None
     with pytest.raises(SparkExpectationsMiscException,
                        match="The spark expectations context is not set completely, please assign "
@@ -318,7 +318,7 @@ def test_get_dq_run_status_exception():


 def test_get_row_dq_rule_type_name_exception():
-    context = SparkExpectationsContext(product_id="product1")
+    context = SparkExpectationsContext(product_id="product1", spark=spark)
     context._row_dq_rule_type_name = None
     with pytest.raises(SparkExpectationsMiscException,
                        match="The spark expectations context is not set completely, please assign "
@@ -327,7 +327,7 @@ def test_get_row_dq_rule_type_name_exception():


 def test_get_agg_dq_rule_type_name_exception():
-    context = SparkExpectationsContext(product_id="product1")
+    context = SparkExpectationsContext(product_id="product1", spark=spark)
     context._agg_dq_rule_type_name = None
     with pytest.raises(SparkExpectationsMiscException,
                        match="The spark expectations context is not set completely, please assign "
@@ -335,17 +335,37 @@ def test_get_agg_dq_rule_type_name_exception():
         context.get_agg_dq_rule_type_name


-def test_get_query_dq_rule_type_name_exception():
-    context = SparkExpectationsContext(product_id="product1")
-    context._query_dq_rule_type_name = None
-    with pytest.raises(SparkExpectationsMiscException,
-                       match="The spark expectations context is not set completely, please assign "
-                             "'_query_dq_rule_type_name' before \n accessing it"):
-        context.get_query_dq_rule_type_name
+
+def test_set_source_query_dq_result():
+    context = SparkExpectationsContext(product_id="product1", spark=spark)
+    value = [{'1':'2'}]
+    context.set_source_query_dq_result(value)
+    assert context.get_source_query_dq_result == value
+
+
+def test_set_final_query_dq_result():
+    context = SparkExpectationsContext(product_id="product1", spark=spark)
+    value = [{'1':'2'}]
+    context.set_final_query_dq_result(value)
+    assert context.get_final_query_dq_result == value
+
+
+def test_get_query_dq_rule_type_name():
+    context = SparkExpectationsContext(product_id="product1", spark=spark)
+    values = [None, 'query_dq']
+    for value in values:
+        context._query_dq_rule_type_name = value
+        if value is None:
+            with pytest.raises(SparkExpectationsMiscException,
+                               match="The spark expectations context is not set completely, please assign "
+                                     "'_query_dq_rule_type_name' before \n accessing it"):
+                context.get_query_dq_rule_type_name
+        else:
+            assert context.get_query_dq_rule_type_name == value


 def test_get_dq_stats_table_name_exception():
-    context = SparkExpectationsContext(product_id="product1")
+    context = SparkExpectationsContext(product_id="product1", spark=spark)
     with pytest.raises(SparkExpectationsMiscException,
                        match="The spark expectations context is not set completely, please assign "
                              "'_dq_stats_table_name' before \n accessing it"):
@@ -353,7 +373,7 @@ def test_get_dq_stats_table_name_exception():


 def test_get_final_table_name_exception():
-    context = SparkExpectationsContext(product_id="product1")
+    context = SparkExpectationsContext(product_id="product1", spark=spark)
with pytest.raises(SparkExpectationsMiscException, match="The spark expectations context is not set completely, please assign " "'_final_table_name' before \n accessing it"): @@ -361,7 +381,7 @@ def test_get_final_table_name_exception(): def test_get_error_table_name_exception(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) with pytest.raises(SparkExpectationsMiscException, match="The spark expectations context is not set completely, please assign " "'_error_table_name' before \n accessing it"): @@ -369,14 +389,14 @@ def test_get_error_table_name_exception(): def test_get_config_file_path(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context._dq_config_abs_path = "spark_expectations/config/file" assert context.get_config_file_path == "spark_expectations/config/file" def test_get_config_file_path_exception(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context._dq_config_abs_path = None with pytest.raises(SparkExpectationsMiscException, match="""The spark expectations context is not set completely, please assign '_dq_config_abs_path' before @@ -385,21 +405,21 @@ def test_get_config_file_path_exception(): def test_set_enable_mail(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context.set_enable_mail(True) assert context._enable_mail is True assert context.get_enable_mail is True def test_set_smtp_server(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context.set_mail_smtp_server("abc") assert context._mail_smtp_server == "abc" assert context.get_mail_smtp_server == "abc" def test_set_smtp_port(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context.set_mail_smtp_port(25) context.set_mail_smtp_port(context._mail_smtp_port) assert context._mail_smtp_port == 25 @@ -407,42 +427,42 @@ def test_set_smtp_port(): def test_set_to_mail(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context.set_to_mail("abc@mail.com, def@mail.com") assert context._to_mail == "abc@mail.com, def@mail.com" assert context.get_to_mail == "abc@mail.com, def@mail.com" def test_set_mail_from(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context.set_mail_from("abc@mail.com") assert context._mail_from == "abc@mail.com" assert context.get_mail_from == "abc@mail.com" def test_set_mail_subject(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context.set_mail_subject("spark expectations") assert context._mail_subject == "spark expectations" assert context.get_mail_subject == "spark expectations" def test_set_enable_slack(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context.set_enable_slack(True) assert context._enable_slack is True assert context.get_enable_slack is True def test_set_slack_webhook_url(): - context = 
SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context.set_slack_webhook_url("abcdefghi") assert context._slack_webhook_url == "abcdefghi" assert context.get_slack_webhook_url == "abcdefghi" def test_table_name(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context.set_table_name("test_table") assert context._table_name == "test_table" assert context._table_name == "test_table" @@ -450,7 +470,7 @@ def test_table_name(): def test_set_input_count(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context.set_input_count(100) assert context._input_count == 100 assert context._input_count == 100 @@ -458,28 +478,28 @@ def test_set_input_count(): def test_set_error_count(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context.set_error_count(10) assert context._error_count == 10 assert context.get_error_count == 10 def test_set_output_count(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context.set_output_count(90) assert context._output_count == 90 assert context.get_output_count == 90 -# def test_set_nsp_stats_topic_name(): +# def test_set_kafka_stats_topic_name(): # context = SparkExpectationsContext(product_id="product1") -# context.set_nsp_stats_topic_name("spark_expectations_stats_topic") -# assert context._nsp_stats_topic_name == "spark_expectations_stats_topic" -# assert context.get_nsp_stats_topic_name == "spark_expectations_stats_topic" +# context.set_kafka_stats_topic_name("spark_expectations_stats_topic") +# assert context._kafka_stats_topic_name == "spark_expectations_stats_topic" +# assert context.get_kafka_stats_topic_name == "spark_expectations_stats_topic" -def test_set_nsp_source_agg_dq_result(): - context = SparkExpectationsContext(product_id="product1") +def test_set_kafka_source_agg_dq_result(): + context = SparkExpectationsContext(product_id="product1", spark=spark) context.set_source_agg_dq_result([ {"action_if_failed": "ignore", "rule_type": "agg_dq", "rule_name": "sum_of_salary_threshold", "rule": "sum(salary)>100"}]) @@ -491,8 +511,8 @@ def test_set_nsp_source_agg_dq_result(): "rule": "sum(salary)>100"}] -def test_set_nsp_final_agg_dq_result(): - context = SparkExpectationsContext(product_id="product1") +def test_set_kafka_final_agg_dq_result(): + context = SparkExpectationsContext(product_id="product1", spark=spark) context.set_final_agg_dq_result([ {"action_if_failed": "ignore", "rule_type": "agg_dq", "rule_name": "sum_of_salary_threshold", "rule": "sum(salary)>100"}]) @@ -505,7 +525,7 @@ def test_set_nsp_final_agg_dq_result(): def test_get_mail_smtp_server_exception(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context._mail_smtp_server = None with pytest.raises(SparkExpectationsMiscException, match="The spark expectations context is not set completely, please assign " @@ -514,7 +534,7 @@ def test_get_mail_smtp_server_exception(): def test_get_mail_smtp_port_exception(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context._mail_smtp_port = 0 with 
pytest.raises(SparkExpectationsMiscException, match="The spark expectations context is not set completely, please assign " @@ -523,7 +543,7 @@ def test_get_mail_smtp_port_exception(): def test_get_to_mail_exception(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context._to_mail = False with pytest.raises(SparkExpectationsMiscException, match="The spark expectations context is not set completely, please assign " @@ -532,7 +552,7 @@ def test_get_to_mail_exception(): def test_get_mail_subject_exception(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context._to_mail = False with pytest.raises(SparkExpectationsMiscException, match="The spark expectations context is not set completely, please assign " @@ -541,7 +561,7 @@ def test_get_mail_subject_exception(): def test_get_mail_from_exception(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context._mail_from = False with pytest.raises(SparkExpectationsMiscException, match="The spark expectations context is not set completely, please assign " @@ -550,7 +570,7 @@ def test_get_mail_from_exception(): def test_get_slack_webhook_url_exception(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context._slack_webhook_url = False with pytest.raises(SparkExpectationsMiscException, match="The spark expectations context is not set completely, please assign " @@ -559,7 +579,7 @@ def test_get_slack_webhook_url_exception(): def test_get_table_name_expection(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context._table_name = "" with pytest.raises(SparkExpectationsMiscException, match="The spark expectations context is not set completely, please assign " @@ -576,51 +596,51 @@ def test_get_table_name_expection(): # context.get_input_count -# def test_get_nsp_stats_topic_name_exception(): +# def test_get_kafka_stats_topic_name_exception(): # context = SparkExpectationsContext(product_id="product1") -# context._nsp_stats_topic_name = None +# context._kafka_stats_topic_name = None # with pytest.raises(SparkExpectationsMiscException, # match="The spark expectations context is not set completely, please assign " -# "'_nsp_stats_topic_name' before \n accessing it"): -# context.get_nsp_stats_topic_name +# "'_kafka_stats_topic_name' before \n accessing it"): +# context.get_kafka_stats_topic_name def test_set_notification_on_start(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context.set_notification_on_start(True) assert context._notification_on_start is True assert context.get_notification_on_start is True def test_set_notification_on_completion(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context.set_notification_on_completion(True) assert context._notification_on_completion is True assert context.get_notification_on_completion is True def test_set_notification_on_fail(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context.set_notification_on_fail(True) assert 
context._notification_on_fail is True assert context.get_notification_on_fail is True def test_set_env(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context._table_name = "dq_spark_staging.test_table" context.set_env("staging") assert context.get_env == "staging" def test_get_env(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context._env = "dev1" assert context.get_env == "dev1" def test_get_error_percentage(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context._input_count = 100 context._error_count = 50 @@ -628,7 +648,7 @@ def test_get_error_percentage(): def test_get_output_percentage(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context._input_count = 100 context._output_count = 50 @@ -636,7 +656,7 @@ def test_get_output_percentage(): def test_get_success_percentage(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context._input_count = 100 context._output_count = 50 context._error_count = 25 @@ -645,7 +665,7 @@ def test_get_success_percentage(): def test_get_error_drop_percentage(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context._input_count = 100 context._output_count = 50 context._error_count = 25 @@ -654,14 +674,14 @@ def test_get_error_drop_percentage(): def test_set_error_threshold(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context.set_error_drop_threshold(100) assert context._error_drop_threshold == 100 assert context.get_error_drop_threshold == 100 def test_get_error_threshold(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context._error_drop_threshold = 0 with pytest.raises(SparkExpectationsMiscException, match="""The spark expectations context is not set completely, please assign '_error_drop_threshold' before @@ -670,7 +690,7 @@ def test_get_error_threshold(): def test_get_cerberus_url_exception(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context._cerberus_url = None with pytest.raises(SparkExpectationsMiscException, match="""The spark expectations context is not set completely, please assign '_cerberus_url' before @@ -679,7 +699,7 @@ def test_get_cerberus_url_exception(): def test_get_cerberus_cred_path_exception(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context._cerberus_cred_path = None with pytest.raises(SparkExpectationsMiscException, match="""The spark expectations context is not set completely, please assign '_cerberus_cred_path' before @@ -688,7 +708,7 @@ def test_get_cerberus_cred_path_exception(): def test_get_cerberus_token_exception(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context._cerberus_token = None with pytest.raises(SparkExpectationsMiscException, match="""The spark 
expectations context is not set completely, please assign '_cerberus_token' before @@ -696,151 +716,151 @@ def test_get_cerberus_token_exception(): context.get_cerberus_token -# def test_get_nsp_bootstrap_server_url_exception(): +# def test_get_kafka_bootstrap_server_url_exception(): # context = SparkExpectationsContext(product_id="product1") -# context._nsp_bootstrap_server_url = None +# context._kafka_bootstrap_server_url = None # with pytest.raises(SparkExpectationsMiscException, -# match="""The spark expectations context is not set completely, please assign '_nsp_bootstrap_server_url' before +# match="""The spark expectations context is not set completely, please assign '_kafka_bootstrap_server_url' before # accessing it"""): -# context.get_nsp_bootstrap_server_url +# context.get_kafka_bootstrap_server_url -# def test_get_nsp_secret_exception(): +# def test_get_kafka_secret_exception(): # context = SparkExpectationsContext(product_id="product1") -# context._nsp_secret = None +# context._kafka_secret = None # with pytest.raises(SparkExpectationsMiscException, -# match="""The spark expectations context is not set completely, please assign '_nsp_secret' before +# match="""The spark expectations context is not set completely, please assign '_kafka_secret' before # accessing it"""): -# context.get_nsp_secret +# context.get_kafka_secret -# def test_get_nsp_token_endpoint_uri_exception(): +# def test_get_kafka_token_endpoint_uri_exception(): # context = SparkExpectationsContext(product_id="product1") -# context._nsp_token_endpoint_uri = None +# context._kafka_token_endpoint_uri = None # with pytest.raises(SparkExpectationsMiscException, -# match="""The spark expectations context is not set completely, please assign '_nsp_token_endpoint_uri' before +# match="""The spark expectations context is not set completely, please assign '_kafka_token_endpoint_uri' before # accessing it"""): -# context.get_nsp_token_endpoint_uri +# context.get_kafka_token_endpoint_uri -# def test_get_nsp_client_id_exception(): +# def test_get_kafka_client_id_exception(): # context = SparkExpectationsContext(product_id="product1") -# context._nsp_client_id = None +# context._kafka_client_id = None # with pytest.raises(SparkExpectationsMiscException, -# match="""The spark expectations context is not set completely, please assign '_nsp_client_id' before +# match="""The spark expectations context is not set completely, please assign '_kafka_client_id' before # accessing it"""): -# context.get_nsp_client_id +# context.get_kafka_client_id -# def test_set_nsp_bootstrap_server_url(): +# def test_set_kafka_bootstrap_server_url(): # context = SparkExpectationsContext(product_id="product1") -# context.set_nsp_bootstrap_server_url(nsp_bootstrap_server_url="https://boostarp/server") -# assert context._nsp_bootstrap_server_url == "https://boostarp/server" -# assert context.get_nsp_bootstrap_server_url == "https://boostarp/server" +# context.set_kafka_bootstrap_server_url(kafka_bootstrap_server_url="https://boostarp/server") +# assert context._kafka_bootstrap_server_url == "https://boostarp/server" +# assert context.get_kafka_bootstrap_server_url == "https://boostarp/server" -# def test_nsp_secret(): +# def test_kafka_secret(): # context = SparkExpectationsContext(product_id="product1") -# context.set_nsp_secret(nsp_secret="xxx") -# assert context._nsp_secret == "xxx" -# assert context.get_nsp_secret == "xxx" +# context.set_kafka_secret(kafka_secret="xxx") +# assert context._kafka_secret == "xxx" +# assert context.get_kafka_secret == 
"xxx" -# def test_nsp_token_endpoint_uri(): +# def test_kafka_token_endpoint_uri(): # context = SparkExpectationsContext(product_id="product1") -# context.set_nsp_token_endpoint_uri(nsp_token_endpoint_uri="https://token_uri") -# assert context._nsp_token_endpoint_uri == "https://token_uri" -# assert context.get_nsp_token_endpoint_uri == "https://token_uri" +# context.set_kafka_token_endpoint_uri(kafka_token_endpoint_uri="https://token_uri") +# assert context._kafka_token_endpoint_uri == "https://token_uri" +# assert context.get_kafka_token_endpoint_uri == "https://token_uri" -# def test_set_nsp_client_id(): +# def test_set_kafka_client_id(): # context = SparkExpectationsContext(product_id="product1") -# context.set_nsp_client_id(nsp_client_id="spark-expectations") -# assert context._nsp_client_id == "spark-expectations" -# assert context.get_nsp_client_id == "spark-expectations" +# context.set_kafka_client_id(kafka_client_id="spark-expectations") +# assert context._kafka_client_id == "spark-expectations" +# assert context.get_kafka_client_id == "spark-expectations" def test_set_source_agg_dq_start_time(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context.set_source_agg_dq_start_time() assert isinstance(context._source_agg_dq_start_time, datetime) def test_set_source_agg_dq_end_time(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context.set_source_agg_dq_end_time() assert isinstance(context._source_agg_dq_end_time, datetime) def test_set_final_agg_dq_start_time(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context.set_final_agg_dq_start_time() assert isinstance(context._final_agg_dq_start_time, datetime) def test_set_final_agg_dq_end_time(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context.set_final_agg_dq_end_time() assert isinstance(context._final_agg_dq_end_time, datetime) def test_set_source_query_dq_start_time(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context.set_source_query_dq_start_time() assert isinstance(context._source_query_dq_start_time, datetime) def test_set_source_query_dq_end_time(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context.set_source_query_dq_end_time() assert isinstance(context._source_query_dq_end_time, datetime) def test_set_final_query_dq_start_time(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context.set_final_query_dq_start_time() assert isinstance(context._final_query_dq_start_time, datetime) def test_set_final_query_dq_end_time(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context.set_final_query_dq_end_time() assert isinstance(context._final_query_dq_end_time, datetime) def test_set_row_dq_start_time(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context.set_row_dq_start_time() assert isinstance(context._row_dq_start_time, datetime) def 
test_set_row_dq_end_time(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context.set_row_dq_end_time() assert isinstance(context._row_dq_end_time, datetime) def test_set_dq_start_time(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context.set_dq_start_time() assert isinstance(context._dq_start_time, datetime) def test_set_dq_end_time(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context.set_dq_end_time() assert isinstance(context._dq_end_time, datetime) def test_get_time_diff(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) assert context.get_time_diff(None, None) == 0.0 now = datetime.now() assert context.get_time_diff(now, now + timedelta(seconds=2)) == 2.0 def test_get_source_agg_dq_run_time(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) now = datetime.now() context._source_agg_dq_start_time = now context._source_agg_dq_end_time = now + timedelta(seconds=2) @@ -848,7 +868,7 @@ def test_get_source_agg_dq_run_time(): def test_get_final_agg_dq_run_time(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) now = datetime.now() context._final_agg_dq_start_time = now context._final_agg_dq_end_time = now + timedelta(seconds=2) @@ -856,7 +876,7 @@ def test_get_final_agg_dq_run_time(): def test_get_source_query_dq_run_time(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) now = datetime.now() context._source_query_dq_start_time = now context._source_query_dq_end_time = now + timedelta(seconds=2) @@ -864,7 +884,7 @@ def test_get_source_query_dq_run_time(): def test_get_final_query_dq_run_time(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) now = datetime.now() context._final_query_dq_start_time = now context._final_query_dq_end_time = now + timedelta(seconds=2) @@ -872,7 +892,7 @@ def test_get_final_query_dq_run_time(): def test_get_row_dq_run_time(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) now = datetime.now() context._row_dq_start_time = now context._row_dq_end_time = now + timedelta(seconds=2) @@ -880,49 +900,61 @@ def test_get_row_dq_run_time(): def test_get_dq_run_time(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) now = datetime.now() context._dq_start_time = now context._dq_end_time = now + timedelta(seconds=2) assert context.get_dq_run_time == 2.0 -def test_get_run_id_name_exception(): - context = SparkExpectationsContext(product_id="product1") - context._run_id_name = None - with pytest.raises(SparkExpectationsMiscException, - match="""The spark expectations context is not set completely, please assign '_run_id_name' before - accessing it"""): - context.get_run_id_name - - -def test_get_run_date_name_exception(): - context = SparkExpectationsContext(product_id="product1") - context._run_date_name = None - with 
pytest.raises(SparkExpectationsMiscException,
-                       match="""The spark expectations context is not set completely, please assign '_run_date_name' before
-                        accessing it"""):
-        context.get_run_date_name
-
-
-def test_get_run_date_time_name_exception():
-    context = SparkExpectationsContext(product_id="product1")
-    context._run_date_time_name = None
-    with pytest.raises(SparkExpectationsMiscException,
-                       match="""The spark expectations context is not set completely, please assign '_run_date_time_name' before
-                        accessing it"""):
-        context.get_run_date_time_name
+def test_get_run_id_name():
+    context = SparkExpectationsContext(product_id="product1", spark=spark)
+    values = [None, 'test']
+    for value in values:
+        context._run_id_name = value
+        if not value:
+            with pytest.raises(SparkExpectationsMiscException,
+                               match="""The spark expectations context is not set completely, please assign '_run_id_name' .*"""):
+                context.get_run_id_name
+        else:
+            assert context.get_run_id_name == value
+
+
+def test_get_run_date_name():
+    context = SparkExpectationsContext(product_id="product1", spark=spark)
+    values = [None, 'test']
+    for value in values:
+        context._run_date_name = value
+        if not value:
+            with pytest.raises(SparkExpectationsMiscException,
+                               match="""The spark expectations context is not set completely, please assign '_run_date_name' .*"""):
+                context.get_run_date_name
+        else:
+            assert context.get_run_date_name == value
+
+
+def test_get_run_date_time_name():
+    context = SparkExpectationsContext(product_id="product1", spark=spark)
+    values = [None, 'test']
+    for value in values:
+        context._run_date_time_name = value
+        if not value:
+            with pytest.raises(SparkExpectationsMiscException,
+                               match="""The spark expectations context is not set completely, please assign '_run_date_time_name' .*"""):
+                context.get_run_date_time_name
+        else:
+            assert context.get_run_date_time_name == value


 def test_set_num_row_dq_rules():
-    context = SparkExpectationsContext(product_id="product1")
+    context = SparkExpectationsContext(product_id="product1", spark=spark)
     context.set_num_row_dq_rules()
     assert context._num_row_dq_rules == 1
     assert context._num_dq_rules == 1


 def test_set_num_agg_dq_rules():
-    context = SparkExpectationsContext(product_id="product1")
+    context = SparkExpectationsContext(product_id="product1", spark=spark)
     context.set_num_agg_dq_rules(True, True)
     assert context.get_num_agg_dq_rules == {"num_source_agg_dq_rules": 1,
                                             "num_final_agg_dq_rules": 1,
@@ -930,24 +962,28 @@ def test_set_num_query_dq_rules():
-    context = SparkExpectationsContext(product_id="product1")
+    context = SparkExpectationsContext(product_id="product1", spark=spark)
     context.set_num_query_dq_rules(True, True)
     assert context.get_num_query_dq_rules == {"num_source_query_dq_rules": 1,
                                               "num_final_query_dq_rules": 1,
                                               "num_query_dq_rules": 1}


-def test_get_num_row_dq_rules_exception():
-    context = SparkExpectationsContext(product_id="product1")
-    context._num_row_dq_rules = 0.0
+def test_get_num_row_dq_rules():
+    context = SparkExpectationsContext(product_id="product1", spark=spark)
+
     with pytest.raises(SparkExpectationsMiscException,
-                       match="""The spark expectations context is not set completely, please assign '_num_row_dq_rules' before
-                        accessing it"""):
+                       match="""The spark expectations context is not set completely, please assign '_num_row_dq_rules' .*"""):
+        context._num_row_dq_rules = None
         context.get_num_row_dq_rules
+
+    context._num_row_dq_rules = 0
+    context.set_num_row_dq_rules()
+    assert context.get_num_row_dq_rules == 1
+
+
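All of these getter tests exercise the same pattern: each `set_x` call on SparkExpectationsContext stores a private `_x` attribute, and the matching `get_x` property raises SparkExpectationsMiscException whenever `_x` was never assigned. A minimal sketch of that pattern, using a hypothetical attribute name purely for illustration:

    from spark_expectations.core.exceptions import SparkExpectationsMiscException

    class ExampleContext:
        def __init__(self) -> None:
            self._example_name = None  # stays None until the setter runs

        def set_example_name(self, name: str) -> None:
            self._example_name = name

        @property
        def get_example_name(self) -> str:
            # mirror the error message asserted by the tests above
            if self._example_name:
                return self._example_name
            raise SparkExpectationsMiscException(
                "The spark expectations context is not set completely, please assign "
                "'_example_name' before \n accessing it"
            )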
def test_get_num_agg_dq_rules_exception():
-    context = SparkExpectationsContext(product_id="product1")
+    context = SparkExpectationsContext(product_id="product1", spark=spark)
     context._num_agg_dq_rules = [1, 2, 3, 4]
     with pytest.raises(SparkExpectationsMiscException,
                        match="""The spark expectations context is not set completely, please assign '_num_agg_dq_rules' before
@@ -956,7 +992,7 @@ def test_get_num_query_dq_rules_exception():
-    context = SparkExpectationsContext(product_id="product1")
+    context = SparkExpectationsContext(product_id="product1", spark=spark)
     context._num_query_dq_rules = [1, 2, 3, 4]
     with pytest.raises(SparkExpectationsMiscException,
                        match="""The spark expectations context is not set completely, please assign '_num_query_dq_rules' before
@@ -964,56 +1000,71 @@
         context.get_num_query_dq_rules


-def test_get_num_dq_rules_exception():
-    context = SparkExpectationsContext(product_id="product1")
-    context._num_dq_rules = 0.0
+def test_get_num_dq_rules():
+    context = SparkExpectationsContext(product_id="product1", spark=spark)
     with pytest.raises(SparkExpectationsMiscException,
-                       match="""The spark expectations context is not set completely, please assign '_num_dq_rules' before
-                        accessing it"""):
+                       match="""The spark expectations context is not set completely, please assign '_num_dq_rules' .*"""):
+        context._num_dq_rules = None
         context.get_num_dq_rules
+
+    context.reset_num_dq_rules()
+    assert context.get_num_dq_rules == 0
+

 def test_set_summarised_row_dq_res():
-    context = SparkExpectationsContext(product_id="product1")
-    context._summarised_row_dq_res = [{"rule": "rule_1",
+    context = SparkExpectationsContext(product_id="product1", spark=spark)
+    context.set_summarised_row_dq_res([{"rule": "rule_1",
                                        "action_if_failed": "ignore",
                                        "failed_row_count": 2},
                                       {"rule": "rule_2",
                                        "action_if_failed": "fail",
-                                       "failed_row_count": 4}]
+                                       "failed_row_count": 4}])
+
     assert context.get_summarised_row_dq_res == [{"rule": "rule_1",
                                                   "action_if_failed": "ignore",
                                                   "failed_row_count": 2},
                                                  {"rule": "rule_2",
                                                   "action_if_failed": "fail",
                                                   "failed_row_count": 4}]


+def test_set_target_and_error_table_writer_config():
+    context = SparkExpectationsContext(product_id="product1", spark=spark)
+    context.set_target_and_error_table_writer_config({'format': 'bigquery'})
+
+    assert context.get_target_and_error_table_writer_config == {'format': 'bigquery'}
+
+
+def test_set_stats_table_writer_config():
+    context = SparkExpectationsContext(product_id="product1", spark=spark)
+    context.set_stats_table_writer_config({'format': 'bigquery'})
+
+    assert context.get_stats_table_writer_config == {'format': 'bigquery'}


-# def test_set_nsp_row_dq_res_topic_name():
+# def test_set_kafka_row_dq_res_topic_name():
 #     context = SparkExpectationsContext(product_id="product1")
-#     context.set_nsp_row_dq_res_topic_name("abc")
-#     assert context.get_nsp_row_dq_res_topic_name == "abc"
+#     context.set_kafka_row_dq_res_topic_name("abc")
+#     assert context.get_kafka_row_dq_res_topic_name == "abc"
 #
-# def test_get_nsp_row_dq_res_topic_name_exception():
+# def test_get_kafka_row_dq_res_topic_name_exception():
 #     context = SparkExpectationsContext(product_id="product1")
-#     context._nsp_row_dq_res_topic_name = None
+#     context._kafka_row_dq_res_topic_name = None
 #     with pytest.raises(SparkExpectationsMiscException,
-#                        match= """The spark expectations context is not set completely, please assign '_nsp_row_dq_res_topic_name' before
+#                        match= """The spark expectations
context is not set completely, please assign '_kafka_row_dq_res_topic_name' before # accessing it"""): -# context.get_nsp_row_dq_res_topic_name +# context.get_kafka_row_dq_res_topic_name def test_set_source_query_dq_status(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context.set_source_query_dq_status("Passed") assert context.get_source_query_dq_status == "Passed" def test_set_final_query_dq_status(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context.set_final_query_dq_status("Passed") assert context.get_final_query_dq_status == "Passed" def test_get_source_query_dq_status_exception(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context._source_query_dq_status = None with pytest.raises(SparkExpectationsMiscException, match="""The spark expectations context is not set completely, please assign '_source_query_dq_status' before @@ -1022,7 +1073,7 @@ def test_get_source_query_dq_status_exception(): def test_get_final_query_dq_status_exception(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context._final_query_dq_status = None with pytest.raises(SparkExpectationsMiscException, match="""The spark expectations context is not set completely, please assign '_final_query_dq_status' before @@ -1031,7 +1082,7 @@ def test_get_final_query_dq_status_exception(): def test_set_supported_df_query_dq(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context._supported_df_query_dq = context.set_supported_df_query_dq() assert context.get_supported_df_query_dq.collect() == get_spark_session().createDataFrame( [ @@ -1043,7 +1094,7 @@ def test_set_supported_df_query_dq(): def test_get_supported_df_query_dq(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context._supported_df_query_dq = None with pytest.raises(SparkExpectationsMiscException, match="""The spark expectations context is not set completely, please assign '_supported_df_query_dq' before @@ -1052,61 +1103,61 @@ def test_get_supported_df_query_dq(): def test_set_debugger_mode(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context.set_debugger_mode(True) assert context._debugger_mode == True def test_get_debugger_mode(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context.set_debugger_mode(True) assert context.get_debugger_mode == True def test_print_dataframe_with_debugger(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context.set_debugger_mode(True) context.print_dataframe_with_debugger(context.set_supported_df_query_dq()) def test_get_error_percentage_negative(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context._input_count = 0 assert context.get_error_percentage == 0.0 def test_get_error_percentage_negative(): - context = SparkExpectationsContext(product_id="product1") + 
context = SparkExpectationsContext(product_id="product1", spark=spark)
     context._input_count = 0

     assert context.get_error_percentage == 0.0


 def test_get_output_percentage_negative():
-    context = SparkExpectationsContext(product_id="product1")
+    context = SparkExpectationsContext(product_id="product1", spark=spark)
     context._input_count = 0

     assert context.get_output_percentage == 0.0


 def test_get_success_percentage_negative():
-    context = SparkExpectationsContext(product_id="product1")
+    context = SparkExpectationsContext(product_id="product1", spark=spark)
     context._input_count = 0

     assert context.get_success_percentage == 0.0


 def test_get_error_drop_percentage_negative():
-    context = SparkExpectationsContext(product_id="product1")
+    context = SparkExpectationsContext(product_id="product1", spark=spark)
     context._input_count = 0

     assert context.get_error_drop_percentage == 0.0


 def test_reset_num_row_dq_rules():
-    context = SparkExpectationsContext(product_id="product1")
+    context = SparkExpectationsContext(product_id="product1", spark=spark)
     context.reset_num_row_dq_rules()
     assert context._num_row_dq_rules == 0


 def test_reset_num_row_dq_rules():
-    context = SparkExpectationsContext(product_id="product1")
+    context = SparkExpectationsContext(product_id="product1", spark=spark)
     context.reset_num_agg_dq_rules()
     assert context._num_agg_dq_rules == {
         "num_agg_dq_rules": 0,
@@ -1116,7 +1167,7 @@ def test_reset_num_row_dq_rules():


 def test_reset_num_query_dq_rules():
-    context = SparkExpectationsContext(product_id="product1")
+    context = SparkExpectationsContext(product_id="product1", spark=spark)
     context.reset_num_query_dq_rules()
     assert context._num_query_dq_rules == {
         "num_query_dq_rules": 0,
@@ -1125,35 +1176,48 @@ def test_reset_num_query_dq_rules():
     }

+
+def test_set_end_time_when_dq_job_fails():
+    context = SparkExpectationsContext(product_id="product1", spark=spark)
+    attributes = ['source_agg', 'source_query', 'row', 'final_agg', 'final_query']
+    for attribute in attributes:
+        setattr(context, f'_{attribute}_dq_start_time', datetime.now())
+        setattr(context, f'_{attribute}_dq_end_time', None)
+        context.set_end_time_when_dq_job_fails()
+        datetime_actual = getattr(context, f'_{attribute}_dq_end_time')
+        assert datetime_actual.date() == date.today()
+
+
 def test_reset_num_dq_rules():
-    context = SparkExpectationsContext(product_id="product1")
+    context = SparkExpectationsContext(product_id="product1", spark=spark)
     context.reset_num_dq_rules()
     assert context._num_dq_rules == 0


 def test_set_se_streaming_stats_dict():
-    context = SparkExpectationsContext(product_id="product1")
+    context = SparkExpectationsContext(product_id="product1", spark=spark)
     context.set_se_streaming_stats_dict({"a": "b", "c": "d"})

     assert context.get_se_streaming_stats_dict == {"a": "b", "c": "d"}


 def get_set_se_streaming_stats_dict():
-    context = SparkExpectationsContext(product_id="product1")
+    context = SparkExpectationsContext(product_id="product1", spark=spark)
     context.set_se_streaming_stats_dict({"a": "b", "c": "d"})

     assert context.get_se_streaming_stats_dict == context._se_streaming_stats_dict


 def test_get_secret_type():
-    context = SparkExpectationsContext(product_id="product1")
+    context = SparkExpectationsContext(product_id="product1", spark=spark)
     context.set_se_streaming_stats_dict({user_config.secret_type: "a"})

     assert context.get_secret_type == "a"


 def test_get_secret_type_exception():
-    context = SparkExpectationsContext(product_id="product1")
+    context = SparkExpectationsContext(product_id="product1", spark=spark)
context.set_se_streaming_stats_dict({user_config.se_enable_streaming: "a"}) with pytest.raises(SparkExpectationsMiscException, @@ -1164,7 +1228,7 @@ def test_get_secret_type_exception(): def test_get_server_url_key(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context.set_se_streaming_stats_dict({user_config.dbx_kafka_server_url: "b", user_config.secret_type: "databricks"}) @@ -1177,7 +1241,7 @@ def test_get_server_url_key(): def test_get_server_url_key_exception(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context.set_se_streaming_stats_dict({user_config.dbx_kafka_server_url: "b", user_config.secret_type: "cerberus"}) @@ -1189,7 +1253,7 @@ def test_get_server_url_key_exception(): def test_get_token_endpoint_url(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context.set_se_streaming_stats_dict({user_config.dbx_secret_token_url: "d", user_config.secret_type: "databricks"}) @@ -1202,7 +1266,7 @@ def test_get_token_endpoint_url(): def test_get_token_endpoint_url_exception(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context.set_se_streaming_stats_dict({user_config.dbx_secret_token_url: "d", user_config.secret_type: "cerberus"}) @@ -1214,7 +1278,7 @@ def test_get_token_endpoint_url_exception(): def test_get_token(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context.set_se_streaming_stats_dict({user_config.dbx_secret_token: "g", user_config.secret_type: "databricks"}) @@ -1227,7 +1291,7 @@ def test_get_token(): def test_get_token_exception(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context.set_se_streaming_stats_dict({user_config.dbx_secret_token_url: "g", user_config.secret_type: "cerberus"}) @@ -1239,7 +1303,7 @@ def test_get_token_exception(): def test_get_client_id(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context.set_se_streaming_stats_dict({user_config.dbx_secret_app_name: "i", user_config.secret_type: "databricks"}) @@ -1252,7 +1316,7 @@ def test_get_client_id(): def test_get_client_id_exception(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context.set_se_streaming_stats_dict({user_config.dbx_secret_app_name: "g", user_config.secret_type: "cerberus"}) @@ -1264,7 +1328,7 @@ def test_get_client_id_exception(): def test_get_topic_name(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context.set_se_streaming_stats_dict({user_config.dbx_topic_name: "k", user_config.secret_type: "databricks"}) @@ -1277,7 +1341,7 @@ def test_get_topic_name(): def test_get_topic_name_exception(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context.set_se_streaming_stats_dict({user_config.dbx_topic_name: "k", user_config.secret_type: "cerberus"}) @@ -1289,21 +1353,21 @@ def test_get_topic_name_exception(): def 
test_set_se_streaming_stats_topic_name(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context.set_se_streaming_stats_topic_name("test_topic") assert context.get_se_streaming_stats_topic_name == "test_topic" def test_get_se_streaming_stats_topic_name(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context.set_se_streaming_stats_topic_name("test_topic") assert context.get_se_streaming_stats_topic_name == context.get_se_streaming_stats_topic_name def test_get_se_streaming_stats_topic_name_exception(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context.set_se_streaming_stats_topic_name("") with pytest.raises(SparkExpectationsMiscException, @@ -1313,7 +1377,7 @@ def test_get_se_streaming_stats_topic_name_exception(): context.get_se_streaming_stats_topic_name def test_set_rules_exceeds_threshold(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context.set_rules_exceeds_threshold([ { "rule_name": 'rule_1', @@ -1334,7 +1398,7 @@ def test_set_rules_exceeds_threshold(): }] def test_get_rules_exceds_threshold(): - context = SparkExpectationsContext(product_id="product1") + context = SparkExpectationsContext(product_id="product1", spark=spark) context._rules_error_per=[ { "rule_name": 'rule_1', diff --git a/tests/core/test_expectations.py b/tests/core/test_expectations.py index e3bae4e8..d74efe6a 100644 --- a/tests/core/test_expectations.py +++ b/tests/core/test_expectations.py @@ -1,22 +1,22 @@ # pylint: disable=too-many-lines import os +import datetime from unittest.mock import Mock from unittest.mock import patch import pytest from pyspark.sql import DataFrame from pyspark.sql.functions import lit, to_timestamp, col from pyspark.sql.types import StringType, IntegerType, StructField, StructType + from spark_expectations.core.context import SparkExpectationsContext -from spark_expectations.utils.reader import SparkExpectationsReader -from spark_expectations.sinks.utils.writer import SparkExpectationsWriter -from spark_expectations.core.expectations import SparkExpectations +from spark_expectations.core.expectations import SparkExpectations, WrappedDataFrameWriter from spark_expectations.config.user_config import Constants as user_config from spark_expectations.core import get_spark_session from spark_expectations.core.exceptions import ( SparkExpectationsMiscException ) -from spark_expectations.notifications.push.spark_expectations_notify import SparkExpectationsNotify -from spark_expectations.sinks.utils.collect_statistics import SparkExpectationsCollectStatistics + +# os.environ["UNIT_TESTING_ENV"] = "local" spark = get_spark_session() @@ -65,27 +65,25 @@ def fixture_dq_rules(): "num_final_agg_dq_rules": 0}} -@pytest.fixture(name="_fixture_expectations") -def fixture_expectations(): - # create a sample input expectations to run on raw dataframe - return { # expectations rules - "row_dq_rules": [{ - "product_id": "product1", - "target_table_name": "dq_spark.test_table", - "rule_type": "row_dq", - "rule": "col1_threshold", - "column_name": "col1", - "expectation": "col1 > 1", - "action_if_failed": "ignore", - "tag": "validity", - "description": "col1 value must be greater than 1", - "enable_error_drop_alert": True, - "error_drop_threshold": "10", - }], 
- "agg_dq_rules": [{}], - "target_table_name": "dq_spark.test_final_table" - +@pytest.fixture(name="_fixture_rules_df") +def fixture_rules_df(): + rules_dict = { + "product_id": "product1", + "table_name": "dq_spark.test_table", + "rule_type": "row_dq", + "rule": "col1_threshold", + "column_name": "col1", + "expectation": "col1 > 1", + "action_if_failed": "ignore", + "tag": "validity", + "description": "col1 value must be greater than 1", + "enable_for_source_dq_validation": True, + "enable_for_target_dq_validation": True, + "is_active": True, + "enable_error_drop_alert": True, + "error_drop_threshold": "10" } + return spark.createDataFrame([rules_dict]) @pytest.fixture(name="_fixture_create_database") @@ -103,7 +101,7 @@ def fixture_create_database(): @pytest.fixture(name="_fixture_context") def fixture_context(): - _context: SparkExpectationsContext = SparkExpectationsContext("product_id") + _context: SparkExpectationsContext = SparkExpectationsContext("product_id", spark) _context.set_table_name("dq_spark.test_final_table") _context.set_dq_stats_table_name("dq_spark.test_dq_stats_table") _context.set_final_table_name("dq_spark.test_final_table") @@ -119,81 +117,80 @@ def fixture_context(): @pytest.fixture(name="_fixture_spark_expectations") -def fixture_spark_expectations(_fixture_context): +def fixture_spark_expectations(_fixture_rules_df): # create a spark expectations class object - spark_expectations = SparkExpectations("product1") + writer = WrappedDataFrameWriter().mode("append").format("delta") + spark_expectations = SparkExpectations(product_id="product1", + rules_df=_fixture_rules_df, + stats_table="dq_spark.test_dq_stats_table", + stats_table_writer=writer, + target_and_error_table_writer=writer, + debugger=False, + ) def _error_threshold_exceeds(expectations): pass - spark_expectations._context = _fixture_context - spark_expectations.reader = SparkExpectationsReader("product1", _fixture_context) - spark_expectations._writer = SparkExpectationsWriter("product1", _fixture_context) - spark_expectations._notification = SparkExpectationsNotify("product1", _fixture_context) - spark_expectations._notification.notify_rules_exceeds_threshold = _error_threshold_exceeds - spark_expectations._statistics_decorator = SparkExpectationsCollectStatistics("product1", - spark_expectations._context, - spark_expectations._writer) + # spark_expectations.reader = SparkExpectationsReader(spark_expectations._context) + # spark_expectations._writer = SparkExpectationsWriter(spark_expectations._context) + # spark_expectations._notification = SparkExpectationsNotify( spark_expectations._context) + # spark_expectations._notification.notify_rules_exceeds_threshold = _error_threshold_exceeds + # spark_expectations._statistics_decorator = SparkExpectationsCollectStatistics(spark_expectations._context, + # spark_expectations._writer) return spark_expectations -@pytest.fixture(name="_fixture_create_stats_table") -def fixture_create_stats_table(): - # drop if exist dq_spark database and create with test_dq_stats_table - os.system("rm -rf /tmp/hive/warehouse/dq_spark.db") - spark.sql("create database if not exists dq_spark") - spark.sql("use dq_spark") - - spark.sql("drop table if exists test_dq_stats_table") - os.system("rm -rf /tmp/hive/warehouse/dq_spark.db/test_dq_stats_table") - - spark.sql( - """ - create table test_dq_stats_table ( - product_id STRING, - table_name STRING, - input_count LONG, - error_count LONG, - output_count LONG, - output_percentage FLOAT, - success_percentage FLOAT, - 
error_percentage FLOAT, - source_agg_dq_results array>, - final_agg_dq_results array>, - source_query_dq_results array>, - final_query_dq_results array>, - row_dq_res_summary array>, - row_dq_error_threshold array>, - dq_status map, - dq_run_time map, - dq_rules map>, - meta_dq_run_id STRING, - meta_dq_run_date DATE, - meta_dq_run_datetime TIMESTAMP - ) - USING delta - """ - ) - - yield "test_dq_stats_table" - - spark.sql("drop table if exists test_dq_stats_table") - os.system("rm -rf /tmp/hive/warehouse/dq_spark.db/test_dq_stats_table") - - # remove database - os.system("rm -rf /tmp/hive/warehouse/dq_spark.db") +# @pytest.fixture(name="_fixture_create_stats_table") +# def fixture_create_stats_table(): +# # drop if exist dq_spark database and create with test_dq_stats_table +# os.system("rm -rf /tmp/hive/warehouse/dq_spark.db") +# spark.sql("create database if not exists dq_spark") +# spark.sql("use dq_spark") +# +# spark.sql("drop table if exists test_dq_stats_table") +# os.system("rm -rf /tmp/hive/warehouse/dq_spark.db/test_dq_stats_table") +# +# spark.sql( +# """ +# create table test_dq_stats_table ( +# product_id STRING, +# table_name STRING, +# input_count LONG, +# error_count LONG, +# output_count LONG, +# output_percentage FLOAT, +# success_percentage FLOAT, +# error_percentage FLOAT, +# source_agg_dq_results array>, +# final_agg_dq_results array>, +# source_query_dq_results array>, +# final_query_dq_results array>, +# row_dq_res_summary array>, +# row_dq_error_threshold array>, +# dq_status map, +# dq_run_time map, +# dq_rules map>, +# meta_dq_run_id STRING, +# meta_dq_run_date DATE, +# meta_dq_run_datetime TIMESTAMP +# ) +# USING delta +# """ +# ) +# +# yield "test_dq_stats_table" +# +# spark.sql("drop table if exists test_dq_stats_table") +# os.system("rm -rf /tmp/hive/warehouse/dq_spark.db/test_dq_stats_table") +# +# # remove database +# os.system("rm -rf /tmp/hive/warehouse/dq_spark.db") @pytest.mark.parametrize("input_df, " "expectations, " "write_to_table, " "write_to_temp_table, " - "row_dq, agg_dq, " - "source_agg_dq, " - "final_agg_dq, " - "query_dq, " - "source_query_dq, " - "final_query_dq, " "expected_output, " "input_count, " "error_count, " @@ -207,7 +204,7 @@ def fixture_create_stats_table(): [ ( # Note: where err: refers error table and fnl: final table - # test case 1 + # test case 0 # In this test case, the action for failed rows is "ignore", # so the function should return the input DataFrame with all rows. 
# collect stats in the test_stats_table and @@ -223,10 +220,10 @@ def fixture_create_stats_table(): # row meets expectations(ignore), log into final table ] ), - { # expectations rules - "row_dq_rules": [{ + [ + { "product_id": "product1", - "target_table_name": "dq_spark.test_table", + "table_name": "dq_spark.test_final_table", "rule_type": "row_dq", "rule": "col1_threshold", "column_name": "col1", @@ -234,22 +231,15 @@ def fixture_create_stats_table(): "action_if_failed": "ignore", "tag": "validity", "description": "col1 value must be greater than 1", + "enable_for_source_dq_validation": True, + "enable_for_target_dq_validation": True, + "is_active": True, "enable_error_drop_alert": True, - "error_drop_threshold": "10", - }], - "agg_dq_rules": [{}], - "target_table_name": "dq_spark.test_final_table" - - }, + "error_drop_threshold": "10" + } + ], True, # write to table True, # write to temp table - True, # row_dq - False, # agg_dq - False, # source_agg_dq - False, # final_agg_dq - False, # query_dq - False, # source_query_dq - False, # final_query_dq # expected res spark.createDataFrame( [ @@ -277,7 +267,7 @@ def fixture_create_stats_table(): ), ( - # test case 2 + # test case 1 # In this test case, the action for failed rows is "drop", # collect stats in the test_stats_table and # log the error records into the error table. @@ -291,10 +281,10 @@ def fixture_create_stats_table(): # row meets expectations(drop), log into final table ] ), - { # expectations rules - "row_dq_rules": [{ + [ + { "product_id": "product1", - "target_table_name": "dq_spark.test_table", + "table_name": "dq_spark.test_final_table", "rule_type": "row_dq", "rule": "col2_set", "column_name": "col2", @@ -302,22 +292,15 @@ def fixture_create_stats_table(): "action_if_failed": "drop", "tag": "strict", "description": "col2 value must be in ('a', 'b')", + "enable_for_source_dq_validation": False, + "enable_for_target_dq_validation": False, + "is_active": True, "enable_error_drop_alert": True, "error_drop_threshold": "5", - }], - "agg_dq_rules": [{}], - "target_table_name": "dq_spark.test_final_table" - - }, + } + ], True, # write to table True, # write to temp table - True, # row_dq - False, # agg_dq - False, # source_agg_dq - False, # final_agg_dq - False, # query_dq - False, # source_query_dq - False, # final_query_dq # expected res spark.createDataFrame( [ @@ -345,7 +328,7 @@ def fixture_create_stats_table(): "source_query_dq_status": "Skipped", "final_query_dq_status": "Skipped"}, ), ( - # test case 3 + # test case 2 # In this test case, the action for failed rows is "fail", # spark expectations expected to fail # collect stats in the test_stats_table and @@ -360,10 +343,10 @@ def fixture_create_stats_table(): # row meets doesn't expectations(fail), log into final table ] ), - { # expectations rules - "row_dq_rules": [{ + [ + { "product_id": "product1", - "target_table_name": "dq_spark.test_table", + "table_name": "dq_spark.test_final_table", "rule_type": "row_dq", "rule": "col3_threshold", "column_name": "col3", @@ -371,22 +354,15 @@ def fixture_create_stats_table(): "action_if_failed": "fail", "tag": "strict", "description": "col3 value must be greater than 6", + "enable_for_source_dq_validation": False, + "enable_for_target_dq_validation": False, + "is_active": True, "enable_error_drop_alert": True, "error_drop_threshold": "15", - }], - "agg_dq_rules": [{}], - "target_table_name": "dq_spark.test_final_table" - - }, + } + ], True, # write to table True, # write to temp table - True, # row_dq - False, # agg_dq - False, 
# source_agg_dq - False, # final_agg_dq_dq - False, # query_dq - False, # source_query_dq - False, # final_query_dq SparkExpectationsMiscException, # expected res 3, # input count 3, # error count @@ -405,7 +381,7 @@ def fixture_create_stats_table(): "final_agg_dq_status": "Skipped", "run_status": "Failed", "source_query_dq_status": "Skipped", "final_query_dq_status": "Skipped"}, ), - ( # test case 4 + ( # test case 3 # In this test case, the action for failed rows is "ignore" & "drop", # collect stats in the test_stats_table and # log the error records into the error table. @@ -419,10 +395,10 @@ def fixture_create_stats_table(): # row doesnt'meets expectations1(ignore), log into final table ] ), - { # expectations rules - "row_dq_rules": [{ + [ + { "product_id": "product1", - "target_table_name": "dq_spark.test_table", + "table_name": "dq_spark.test_final_table", "rule_type": "row_dq", "rule": "col3_threshold", "column_name": "col3", @@ -430,36 +406,32 @@ def fixture_create_stats_table(): "action_if_failed": "ignore", "tag": "strict", "description": "col3 value must be greater than 6", + "enable_for_source_dq_validation": False, + "enable_for_target_dq_validation": False, + "is_active": True, "enable_error_drop_alert": False, "error_drop_threshold": "10", }, - { - "product_id": "product1", - "target_table_name": "dq_spark.test_table", - "rule_type": "row_dq", - "rule": "col1_add_col3_threshold", - "column_name": "col1", - "expectation": "(col1+col3) > 6", - "action_if_failed": "drop", - "tag": "strict", - "description": "col1_add_col3 value must be greater than 6", - "enable_error_drop_alert": False, - "error_drop_threshold": "15", - } - ], - "agg_dq_rules": [{}], - "target_table_name": "dq_spark.test_final_table" + { + "product_id": "product1", + "table_name": "dq_spark.test_final_table", + "rule_type": "row_dq", + "rule": "col1_add_col3_threshold", + "column_name": "col1", + "expectation": "(col1+col3) > 6", + "action_if_failed": "drop", + "tag": "strict", + "description": "col1_add_col3 value must be greater than 6", + "enable_for_source_dq_validation": False, + "enable_for_target_dq_validation": False, + "is_active": True, + "enable_error_drop_alert": False, + "error_drop_threshold": "15", + } + ], - }, True, # write to table True, # write to temp table - True, # row_dq - False, # agg_dq - False, # source_agg_dq - False, # final_agg_dq - False, # query_dq - False, # source_query_dq - False, # final_query_dq # expected res spark.createDataFrame( [ @@ -485,7 +457,7 @@ def fixture_create_stats_table(): "source_query_dq_status": "Skipped", "final_query_dq_status": "Skipped"}, ), ( - # test case 5 + # test case 4 # In this test case, the action for failed rows is "ignore" & "fail", # collect stats in the test_stats_table and # log the error records into the error table. 
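Note on the refactor applied throughout these hunks: every test case swaps the old nested expectations mapping (a dict with "row_dq_rules"/"agg_dq_rules" keys plus a separate "target_table_name" entry) for a flat list of rule dicts, renames the per-rule "target_table_name" field to "table_name", adds the "enable_for_source_dq_validation", "enable_for_target_dq_validation" and "is_active" fields, and drops the row_dq/agg_dq/query_dq booleans from the parametrize tuple. A minimal sketch of one rule in the new flat shape, with values copied from test case 1 above (illustrative only, not normative):

    # One row_dq rule in the new flat shape; field names are exactly those
    # used in the updated test cases, values are taken from test case 1.
    rule = {
        "product_id": "product1",
        "table_name": "dq_spark.test_final_table",
        "rule_type": "row_dq",
        "rule": "col1_threshold",
        "column_name": "col1",
        "expectation": "col1 > 1",
        "action_if_failed": "ignore",
        "tag": "validity",
        "description": "col1 value must be greater than 1",
        "enable_for_source_dq_validation": True,
        "enable_for_target_dq_validation": True,
        "is_active": True,
        "enable_error_drop_alert": True,
        "error_drop_threshold": "10",
    }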
@@ -499,10 +471,10 @@ def fixture_create_stats_table(): # row meets expectations1(ignore), log into final table ] ), - { # expectations rules - "row_dq_rules": [{ + [ + { "product_id": "product1", - "target_table_name": "dq_spark.test_table", + "table_name": "dq_spark.test_final_table", "rule_type": "row_dq", "rule": "col3_threshold", "column_name": "col3", @@ -510,36 +482,31 @@ def fixture_create_stats_table(): "action_if_failed": "ignore", "tag": "strict", "description": "col3 value must be greater than 6", + "enable_for_source_dq_validation": False, + "enable_for_target_dq_validation": False, + "is_active": True, "enable_error_drop_alert": True, "error_drop_threshold": "20", }, - { - "product_id": "product1", - "target_table_name": "dq_spark.test_table", - "rule_type": "row_dq", - "rule": "col3_minus_col1_threshold", - "column_name": "col1", - "expectation": "(col3-col1) > 1", - "action_if_failed": "fail", - "tag": "strict", - "description": "col3_minus_col1 value must be greater than 1", - "enable_error_drop_alert": True, - "error_drop_threshold": "5", - } - ], - "agg_dq_rules": [{}], - "target_table_name": "dq_spark.test_final_table" - - }, + { + "product_id": "product1", + "table_name": "dq_spark.test_final_table", + "rule_type": "row_dq", + "rule": "col3_minus_col1_threshold", + "column_name": "col1", + "expectation": "(col3-col1) > 1", + "action_if_failed": "fail", + "tag": "strict", + "description": "col3_minus_col1 value must be greater than 1", + "enable_for_source_dq_validation": False, + "enable_for_target_dq_validation": False, + "is_active": True, + "enable_error_drop_alert": True, + "error_drop_threshold": "5", + } + ], True, # write to table True, # write to temp table - True, # row_dq - False, # agg_dq - False, # source_agg_dq - False, # final_agg_dq - False, # query_dq - False, # source_query_dq - False, # final_query_dq # expected res spark.createDataFrame( [ @@ -565,7 +532,8 @@ def fixture_create_stats_table(): "final_agg_dq_status": "Skipped", "run_status": "Passed", "source_query_dq_status": "Skipped", "final_query_dq_status": "Skipped"}, ), - ( # Test case 6 + ( + # Test case 5 # In this test case, the action for failed rows is "drop" & "fail", # collect stats in the test_stats_table and # log the error records into the error table. 
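These flat rule lists feed the reworked test body shown in a later hunk: it builds a DataFrame straight from the list with spark.createDataFrame(expectations) and hands it to the new SparkExpectations constructor together with a WrappedDataFrameWriter, replacing the old per-call expectations dict, spark_conf, options and options_error_table arguments. A condensed sketch of that wiring, lifted from the updated test code below (the import paths are an assumption; they are not visible in this patch):

    # Assumed import locations -- the hunks below use these names unqualified.
    from spark_expectations.core.expectations import (SparkExpectations,
                                                      WrappedDataFrameWriter)
    from spark_expectations.config import user_config

    # `rules` is a list of rule dicts in the shape sketched earlier.
    rules_df = spark.createDataFrame(rules)
    writer = WrappedDataFrameWriter().mode("append").format("parquet")

    se = SparkExpectations(product_id="product1",
                           rules_df=rules_df,
                           stats_table="dq_spark.test_dq_stats_table",
                           stats_table_writer=writer,
                           target_and_error_table_writer=writer,
                           debugger=False)

    # Only the target table is mandatory now; writers and per-rule-type
    # toggles no longer travel through the decorator call.
    @se.with_expectations("dq_spark.test_final_table",
                          user_conf={user_config.se_notifications_on_fail: False},
                          write_to_table=True,
                          write_to_temp_table=True)
    def get_dataset():
        return input_df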
@@ -579,10 +547,10 @@ def fixture_create_stats_table(): # row meets expectations1(drop), & 2(fail) log into final table ] ), - { # expectations rules - "row_dq_rules": [{ + [ + { "product_id": "product1", - "target_table_name": "dq_spark.test_table", + "table_name": "dq_spark.test_final_table", "rule_type": "row_dq", "rule": "col3_threshold", "column_name": "col3", @@ -590,36 +558,31 @@ def fixture_create_stats_table(): "action_if_failed": "drop", "tag": "strict", "description": "col3 value must be greater than 6", + "enable_for_source_dq_validation": False, + "enable_for_target_dq_validation": False, + "is_active": True, "enable_error_drop_alert": True, "error_drop_threshold": "25", }, - { - "product_id": "product1", - "target_table_name": "dq_spark.test_table", - "rule_type": "row_dq", - "rule": "col3_minus_col1_threshold", - "column_name": "col1", - "expectation": "(col3-col1) = 1", - "action_if_failed": "fail", - "tag": "strict", - "description": "col3_minus_col1 value must be equals to 1", - "enable_error_drop_alert": False, - "error_drop_threshold": "25", - } - ], - "agg_dq_rules": [{}], - "target_table_name": "dq_spark.test_final_table" - - }, + { + "product_id": "product1", + "table_name": "dq_spark.test_final_table", + "rule_type": "row_dq", + "rule": "col3_minus_col1_threshold", + "column_name": "col1", + "expectation": "(col3-col1) = 1", + "action_if_failed": "fail", + "tag": "strict", + "description": "col3_minus_col1 value must be equals to 1", + "enable_for_source_dq_validation": False, + "enable_for_target_dq_validation": False, + "is_active": True, + "enable_error_drop_alert": False, + "error_drop_threshold": "25", + } + ], True, # write to table True, # write to temp table - True, # row_dq - False, # agg_dq - False, # source_dq - False, # final_dq - False, # query_dq - False, # source_query_dq - False, # final_query_dq SparkExpectationsMiscException, # expected res 3, # input count 3, # error count @@ -652,10 +615,10 @@ def fixture_create_stats_table(): # row meets expectations1(drop) & meets 2(fail), log into final table ] ), - { # expectations rules - "row_dq_rules": [{ + [ + { "product_id": "product1", - "target_table_name": "dq_spark.test_table", + "table_name": "dq_spark.test_final_table", "rule_type": "row_dq", "rule": "col3_threshold", "column_name": "col3", @@ -663,36 +626,31 @@ def fixture_create_stats_table(): "action_if_failed": "drop", "tag": "strict", "description": "col3 value must be greater than 6", + "enable_for_source_dq_validation": False, + "enable_for_target_dq_validation": False, + "is_active": True, "enable_error_drop_alert": True, "error_drop_threshold": "10", }, - { - "product_id": "product1", - "target_table_name": "dq_spark.test_table", - "rule_type": "row_dq", - "rule": "col3_mul_col1_threshold", - "column_name": "col1", - "expectation": "(col3*col1) > 1", - "action_if_failed": "fail", - "tag": "strict", - "description": "col3_mul_col1 value must be equals to 1", - "enable_error_drop_alert": True, - "error_drop_threshold": "10", - } - ], - "agg_dq_rules": [{}], - "target_table_name": "dq_spark.test_final_table" - - }, + { + "product_id": "product1", + "table_name": "dq_spark.test_final_table", + "rule_type": "row_dq", + "rule": "col3_mul_col1_threshold", + "column_name": "col1", + "expectation": "(col3*col1) > 1", + "action_if_failed": "fail", + "tag": "strict", + "description": "col3_mul_col1 value must be equals to 1", + "enable_for_source_dq_validation": False, + "enable_for_target_dq_validation": False, + "is_active": True, + 
"enable_error_drop_alert": True, + "error_drop_threshold": "10", + } + ], True, # write to table True, # write to temp table - True, # row_dq - False, # agg_dq - False, # source_dq - False, # final_dq - False, # query_dq - False, # source_query_dq - False, # final_query_dq # expected res spark.createDataFrame([], schema=StructType([ StructField("col1", IntegerType()), @@ -732,10 +690,10 @@ def fixture_create_stats_table(): # row meets all the expectations ] ), - { # expectations rules - "row_dq_rules": [{ + [ + { "product_id": "product1", - "target_table_name": "dq_spark.test_table", + "table_name": "dq_spark.test_final_table", "rule_type": "row_dq", "rule": "col1_threshold", "column_name": "col1", @@ -743,11 +701,15 @@ def fixture_create_stats_table(): "action_if_failed": "ignore", "tag": "strict", "description": "col1 value must be greater than 1", + "enable_for_source_dq_validation": False, + "enable_for_target_dq_validation": False, + "is_active": True, "enable_error_drop_alert": False, "error_drop_threshold": "0", - }, { + }, + { "product_id": "product1", - "target_table_name": "dq_spark.test_table", + "table_name": "dq_spark.test_final_table", "rule_type": "row_dq", "rule": "col3_threshold", "column_name": "col3", @@ -755,36 +717,31 @@ def fixture_create_stats_table(): "action_if_failed": "drop", "tag": "strict", "description": "col3 value must be greater than 5", + "enable_for_source_dq_validation": False, + "enable_for_target_dq_validation": False, + "is_active": True, "enable_error_drop_alert": True, "error_drop_threshold": "10", }, - { - "product_id": "product1", - "target_table_name": "dq_spark.test_table", - "rule_type": "row_dq", - "rule": "col3_mul_col1_threshold", - "column_name": "col1", - "expectation": "(col3*col1) > 1", - "action_if_failed": "fail", - "tag": "strict", - "description": "col3_mul_col1 value must be equals to 1", - "enable_error_drop_alert": False, - "error_drop_threshold": "20", - } - ], - "agg_dq_rules": [{}], - "target_table_name": "dq_spark.test_final_table" - - }, + { + "product_id": "product1", + "table_name": "dq_spark.test_final_table", + "rule_type": "row_dq", + "rule": "col3_mul_col1_threshold", + "column_name": "col1", + "expectation": "(col3*col1) > 1", + "action_if_failed": "fail", + "tag": "strict", + "description": "col3_mul_col1 value must be equals to 1", + "enable_for_source_dq_validation": False, + "enable_for_target_dq_validation": False, + "is_active": True, + "enable_error_drop_alert": False, + "error_drop_threshold": "20", + } + ], True, # write to table True, # write_to_temp_table - True, # row_dq - False, # agg_dq - False, # source_dq - False, # final_agg_dq - False, # query_dq - False, # source_query_dq - False, # final_query_dq spark.createDataFrame( # expected output [ {"col1": 3, "col2": "c", 'col3': 6}, @@ -817,36 +774,33 @@ def fixture_create_stats_table(): {"col1": 3, "col2": "c", 'col3': 6}, ] ), - { # expectations rules - "agg_dq_rules": [{ + [ + { "product_id": "product1", - "target_table_name": "dq_spark.test_table", + "table_name": "dq_spark.test_final_table", "rule_type": "agg_dq", "rule": "sum_col3_threshold", "column_name": "col3", "expectation": "sum(col3) > 20", - "enable_for_source_dq_validation": True, - "enable_for_target_dq_validation": True, "action_if_failed": "ignore", "tag": "strict", "description": "sum col3 value must be greater than 20", + "enable_for_source_dq_validation": True, + "enable_for_target_dq_validation": False, + "is_active": True, "enable_error_drop_alert": True, "error_drop_threshold": "10", 
- }], - "row_dq_rules": [{}], - "target_table_name": "dq_spark.test_final_table" - - }, + } + ], True, # write to table True, # write to temp table - False, # row_dq - True, # agg_dq - True, # source_agg_dq - False, # final_agg_dq - False, # query_dq - False, # source_query_dq - False, # final_query_dq - None, # expected result + spark.createDataFrame( + [ + {"col1": 1, "col2": "a", "col3": 4}, + {"col1": 2, "col2": "b", "col3": 5}, + {"col1": 3, "col2": "c", 'col3': 6}, + ] + ), # expected result 3, # input count 0, # error count 0, # output count @@ -877,33 +831,26 @@ def fixture_create_stats_table(): {"col1": 3, "col2": "c", 'col3': 6}, ] ), - { # expectations rules - "agg_dq_rules": [{ + [ + { "product_id": "product1", - "target_table_name": "dq_spark.test_table", + "table_name": "dq_spark.test_final_table", "rule_type": "agg_dq", "rule": "avg_col3_threshold", "column_name": "col3", "expectation": "avg(col3) > 25", - "enable_for_source_dq_validation": True, - "enable_for_target_dq_validation": True, "action_if_failed": "fail", "tag": "strict", "description": "avg col3 value must be greater than 25", - }], - "row_dq_rules": [{}], - "target_table_name": "dq_spark.test_final_table" - - }, + "enable_for_source_dq_validation": True, + "enable_for_target_dq_validation": False, + "is_active": True, + "enable_error_drop_alert": False, + "error_drop_threshold": "20", + } + ], True, # write to table True, # write to temp table - False, # row_dq - True, # agg_dq - True, # source_agg_dq - False, # final_agg_dq - False, # query_dq - False, # source_query_dq - False, # final_query_dq SparkExpectationsMiscException, # excepted result 3, # input count 0, # error count @@ -935,23 +882,26 @@ def fixture_create_stats_table(): {"col1": 3, "col2": "c", 'col3': 6}, ] ), - { # expectations rules - "agg_dq_rules": [{ - "product_id": "product1", - "target_table_name": "dq_spark.test_table", - "rule_type": "agg_dq", - "rule": "min_col1_threshold", - "column_name": "col1", - "expectation": "min(col1) > 10", - "enable_for_source_dq_validation": True, - "enable_for_target_dq_validation": True, - "action_if_failed": "ignore", - "tag": "strict", - "description": "min col1 value must be greater than 10", - }], - "row_dq_rules": [{ + [{ + "product_id": "product1", + "table_name": "dq_spark.test_final_table", + "rule_type": "agg_dq", + "rule": "min_col1_threshold", + "column_name": "col1", + "expectation": "min(col1) > 10", + "action_if_failed": "ignore", + "tag": "strict", + "description": "min col1 value must be greater than 10", + "enable_for_source_dq_validation": False, + "enable_for_target_dq_validation": True, + "is_active": True, + "enable_error_drop_alert": False, + "error_drop_threshold": "20", + + }, + { "product_id": "product1", - "target_table_name": "dq_spark.test_table", + "table_name": "dq_spark.test_final_table", "rule_type": "row_dq", "rule": "col2_set", "column_name": "col2", @@ -959,21 +909,15 @@ def fixture_create_stats_table(): "action_if_failed": "drop", "tag": "strict", "description": "col2 value must be in ('a', 'b')", + "enable_for_source_dq_validation": False, + "enable_for_target_dq_validation": False, + "is_active": True, "enable_error_drop_alert": False, "error_drop_threshold": "0", - }], - "target_table_name": "dq_spark.test_final_table" - - }, + } + ], True, # write to table True, # write to temp table - True, # row_dq - True, # agg_dq - False, # source_agg_dq - True, # final_agg_dq - False, # query_dq - False, # source_query_dq - False, # final_query_dq spark.createDataFrame( [ {"col1": 
1, "col2": "a", "col3": 4}, @@ -1018,23 +962,26 @@ def fixture_create_stats_table(): ] ), - { # expectations rules - "agg_dq_rules": [{ + [ + { "product_id": "product1", - "target_table_name": "dq_spark.test_table", + "table_name": "dq_spark.test_final_table", "rule_type": "agg_dq", "rule": "std_col3_threshold", "column_name": "col3", "expectation": "stddev(col3) > 10", - "enable_for_source_dq_validation": True, - "enable_for_target_dq_validation": True, "action_if_failed": "fail", "tag": "strict", "description": "std col3 value must be greater than 10", - }], - "row_dq_rules": [{ + "enable_for_source_dq_validation": False, + "enable_for_target_dq_validation": True, + "is_active": True, + "enable_error_drop_alert": False, + "error_drop_threshold": "20", + }, + { "product_id": "product1", - "target_table_name": "dq_spark.test_table", + "table_name": "dq_spark.test_final_table", "rule_type": "row_dq", "rule": "col2_set", "column_name": "col2", @@ -1042,20 +989,15 @@ def fixture_create_stats_table(): "action_if_failed": "drop", "tag": "strict", "description": "col2 value must be in ('a', 'b')", + "enable_for_source_dq_validation": False, + "enable_for_target_dq_validation": False, + "is_active": True, "enable_error_drop_alert": True, "error_drop_threshold": "15", - }], - "target_table_name": "dq_spark.test_final_table" - }, + } + ], True, # write to table True, # write temp table - True, # row_dq - True, # agg_dq - False, # source_dq_dq - True, # final_agg_dq - False, # query_dq - False, # source_query_dq - False, # final_query_dq SparkExpectationsMiscException, # expected result 3, # input count 1, # error count @@ -1093,10 +1035,10 @@ def fixture_create_stats_table(): # row meets the expectations ] ), - { # expectations rules - "row_dq_rules": [{ + [ + { "product_id": "product1", - "target_table_name": "dq_spark.test_table", + "table_name": "dq_spark.test_final_table", "rule_type": "row_dq", "rule": "col1_threshold", "column_name": "col1", @@ -1104,20 +1046,15 @@ def fixture_create_stats_table(): "action_if_failed": "drop", "tag": "validity", "description": "col1 value must be greater than 1", + "enable_for_source_dq_validation": False, + "enable_for_target_dq_validation": False, + "is_active": True, "enable_error_drop_alert": False, "error_drop_threshold": "0", - }], - "target_table_name": "dq_spark.test_final_table" - }, + } + ], True, # write to table True, # write to temp table - True, # row_dq - False, # agg_dq - False, # source_agg_dq - False, # final_agg_dq - False, # query_dq - False, # source_query_dq - False, # final_query_dq spark.createDataFrame([ # expected_output {"col1": 2, "col2": "b"}, {"col1": 3, "col2": "c"} @@ -1155,10 +1092,10 @@ def fixture_create_stats_table(): # row meets all row_dq expectations ] ), - { # expectations rules - "row_dq_rules": [{ + [ + { "product_id": "product1", - "target_table_name": "dq_spark.test_table", + "table_name": "dq_spark.test_final_table", "rule_type": "row_dq", "rule": "col1_threshold_1", "column_name": "col1", @@ -1166,47 +1103,47 @@ def fixture_create_stats_table(): "action_if_failed": "ignore", "tag": "validity", "description": "col1 value must be greater than 1", + "enable_for_source_dq_validation": False, + "enable_for_target_dq_validation": False, + "is_active": True, + "enable_error_drop_alert": True, + "error_drop_threshold": "10", + }, + { + "product_id": "product1", + "table_name": "dq_spark.test_final_table", + "rule_type": "row_dq", + "rule": "col2_set", + "column_name": "col2", + "expectation": "col2 in ('a', 'b', 'c')", + 
"action_if_failed": "drop", + "tag": "validity", + "description": "col1 value must be greater than 2", + "enable_for_source_dq_validation": False, + "enable_for_target_dq_validation": False, + "is_active": True, "enable_error_drop_alert": True, "error_drop_threshold": "10", }, - { - "product_id": "product1", - "target_table_name": "dq_spark.test_table", - "rule_type": "row_dq", - "rule": "col2_set", - "column_name": "col2", - "expectation": "col2 in ('a', 'b', 'c')", - "action_if_failed": "drop", - "tag": "validity", - "description": "col1 value must be greater than 2", - "enable_error_drop_alert": True, - "error_drop_threshold": "10", - }], - "agg_dq_rules": [{ + { "product_id": "product1", - "target_table_name": "dq_spark.test_table", + "table_name": "dq_spark.test_final_table", "rule_type": "agg_dq", "rule": "distinct_col2_threshold", "column_name": "col2", "expectation": "count(distinct col2) > 4", - "enable_for_source_dq_validation": True, - "enable_for_target_dq_validation": True, "action_if_failed": "ignore", "tag": "validity", "description": "distinct of col2 value must be greater than 4", - }], - "target_table_name": "dq_spark.test_final_table" - - }, + "enable_for_source_dq_validation": True, + "enable_for_target_dq_validation": False, + "is_active": True, + "enable_error_drop_alert": False, + "error_drop_threshold": "20", + } + ], True, # write to table True, # write to temp table - True, # row_dq - True, # agg_dq - True, # source_agg_dq - False, # final_agg_dq - False, # query_dq - False, # source_query_dq - False, # final_query_dq spark.createDataFrame([ # expected_output {"col1": 1, "col2": "a", "col3": 4}, {"col1": 2, "col2": "b", "col3": 5}, @@ -1230,7 +1167,6 @@ def fixture_create_stats_table(): {"row_dq_status": "Passed", "source_agg_dq_status": "Passed", "final_agg_dq_status": "Skipped", "run_status": "Passed", "source_query_dq_status": "Skipped", "final_query_dq_status": "Skipped"} # status - ), ( @@ -1249,10 +1185,10 @@ def fixture_create_stats_table(): # row meets all row-dq expectations ] ), - { # expectations rules - "row_dq_rules": [{ + [ + { "product_id": "product1", - "target_table_name": "dq_spark.test_table", + "table_name": "dq_spark.test_final_table", "rule_type": "row_dq", "rule": "col3_threshold_4", "column_name": "col3", @@ -1260,47 +1196,47 @@ def fixture_create_stats_table(): "action_if_failed": "drop", "tag": "validity", "description": "col3 value must be greater than 4", + "enable_for_source_dq_validation": False, + "enable_for_target_dq_validation": False, + "is_active": True, "enable_error_drop_alert": True, "error_drop_threshold": "20", }, - { - "product_id": "product1", - "target_table_name": "dq_spark.test_table", - "rule_type": "row_dq", - "rule": "col2_set", - "column_name": "col2", - "expectation": "col2 in ('a', 'b')", - "action_if_failed": "ignore", - "tag": "validity", - "description": "col2 value must be in (a, b)", - "enable_error_drop_alert": True, - "error_drop_threshold": "2", - }], - "agg_dq_rules": [{ + { "product_id": "product1", - "target_table_name": "dq_spark.test_table", - "rule_type": "agg_dq", + "table_name": "dq_spark.test_final_table", + "rule_type": "row_dq", + "rule": "col2_set", + "column_name": "col2", + "expectation": "col2 in ('a', 'b')", + "action_if_failed": "ignore", + "tag": "validity", + "description": "col2 value must be in (a, b)", + "enable_for_source_dq_validation": False, + "enable_for_target_dq_validation": False, + "is_active": True, + "enable_error_drop_alert": True, + "error_drop_threshold": "2", + }, + { 
+ "product_id": "product1", + "table_name": "dq_spark.test_final_table", + "rule_type": "agg_dq", "rule": "avg_col1_threshold", "column_name": "col1", "expectation": "avg(col1) > 4", - "enable_for_source_dq_validation": True, - "enable_for_target_dq_validation": True, "action_if_failed": "ignore", "tag": "accuracy", "description": "avg of col1 value must be greater than 4", - }], - "target_table_name": "dq_spark.test_final_table" - - }, + "enable_for_source_dq_validation": True, + "enable_for_target_dq_validation": True, + "is_active": True, + "enable_error_drop_alert": False, + "error_drop_threshold": "20", + } + ], True, # write to table True, # write to temp table - True, # row_dq - True, # agg_dq - True, # source_agg_dq - True, # final_agg_dq - False, # query_dq - False, # source_query_dq - False, # final_query_dq spark.createDataFrame([ # expected_output {"col1": 2, "col2": "b", "col3": 5}, {"col1": 3, "col2": "c", "col3": 6} @@ -1326,7 +1262,6 @@ def fixture_create_stats_table(): {"row_dq_status": "Passed", "source_agg_dq_status": "Passed", "final_agg_dq_status": "Passed", "run_status": "Passed", "source_query_dq_status": "Skipped", "final_query_dq_status": "Skipped"} # status - ), ( # Test case 15 @@ -1346,10 +1281,10 @@ def fixture_create_stats_table(): {"col1": 2, "col2": "d", "col3": 7} ] ), - { # expectations rules - "row_dq_rules": [{ + [ + { "product_id": "product1", - "target_table_name": "dq_spark.test_table", + "table_name": "dq_spark.test_final_table", "rule_type": "row_dq", "rule": "col3_and_col1_threshold_4", "column_name": "col3, col1", @@ -1357,74 +1292,63 @@ def fixture_create_stats_table(): "action_if_failed": "drop", "tag": "validity", "description": "col3 and col1 operation value must be greater than 3", + "enable_for_source_dq_validation": False, + "enable_for_target_dq_validation": False, + "is_active": True, "enable_error_drop_alert": True, "error_drop_threshold": "25", }, - { - "product_id": "product1", - "target_table_name": "dq_spark.test_table", - "rule_type": "row_dq", - "rule": "col2_set", - "column_name": "col2", - "expectation": "col2 in ('b', 'c')", - "action_if_failed": "ignore", - "tag": "validity", - "description": "col2 value must be in (b, c)", - "enable_error_drop_alert": True, - "error_drop_threshold": "30", - }], - "agg_dq_rules": [{ + { + "product_id": "product1", + "table_name": "dq_spark.test_final_table", + "rule_type": "row_dq", + "rule": "col2_set", + "column_name": "col2", + "expectation": "col2 in ('b', 'c')", + "action_if_failed": "ignore", + "tag": "validity", + "description": "col2 value must be in (b, c)", + "enable_for_source_dq_validation": False, + "enable_for_target_dq_validation": False, + "is_active": True, + "enable_error_drop_alert": True, + "error_drop_threshold": "30", + }, + { "product_id": "product1", - "target_table_name": "dq_spark.test_table", + "table_name": "dq_spark.test_final_table", "rule_type": "agg_dq", "rule": "avg_col1_threshold", "column_name": "col1", "expectation": "avg(col1) > 4", - "enable_for_source_dq_validation": True, - "enable_for_target_dq_validation": True, "action_if_failed": "ignore", "tag": "validity", "description": "avg of col1 value must be greater than 4", + "enable_for_source_dq_validation": True, + "enable_for_target_dq_validation": True, + "is_active": True, + "enable_error_drop_alert": False, + "error_drop_threshold": "20", }, - { - "product_id": "product1", - "target_table_name": "dq_spark.test_table", - "rule_type": "agg_dq", - "rule": "stddev_col3_threshold", - "column_name": "col3", 
- "expectation": "stddev(col3) > 1", - "enable_for_source_dq_validation": True, - "enable_for_target_dq_validation": False, - "action_if_failed": "fail", - "tag": "validity", - "description": "stddev of col3 value must be greater than one" - }, - { - "product_id": "product1", - "target_table_name": "dq_spark.test_table", - "rule_type": "agg_dq", - "rule": "stddev_col3_threshold", - "column_name": "col3", - "expectation": "stddev(col3) < 1", - "enable_for_source_dq_validation": False, - "enable_for_target_dq_validation": True, - "action_if_failed": "fail", - "tag": "validity", - "description": "avg of col3 value must be greater than 0", - } - ], - "target_table_name": "dq_spark.test_final_table" - - }, + { + "product_id": "product1", + "table_name": "dq_spark.test_final_table", + "rule_type": "agg_dq", + "rule": "stddev_col3_threshold", + "column_name": "col3", + "expectation": "stddev(col3) > 1", + "action_if_failed": "ignore", + "tag": "validity", + "description": "stddev of col3 value must be greater than one", + "enable_for_source_dq_validation": True, + "enable_for_target_dq_validation": True, + "is_active": True, + "enable_error_drop_alert": False, + "error_drop_threshold": "20", + } + ], True, # write to table True, # write to temp table - True, # row_dq - True, # agg_dq - True, # source_agg_dq - True, # final_agg_dq - False, # query_dq - False, # source_query_dq - False, # final_query_dq spark.createDataFrame([ # expected_output {"col1": 3, "col2": "c", "col3": 6}, {"col1": 2, "col2": "d", "col3": 7} @@ -1436,9 +1360,16 @@ def fixture_create_stats_table(): "rule": "avg_col1_threshold", "rule_type": "agg_dq", "action_if_failed": "ignore", "tag": "validity"}], # source_agg_result - [{"description": "avg of col1 value must be greater than 4", - "rule": "avg_col1_threshold", - "rule_type": "agg_dq", "action_if_failed": "ignore", "tag": "validity"}], + [{'action_if_failed': 'ignore', + 'description': 'avg of col1 value must be greater than 4', + 'rule': 'avg_col1_threshold', + 'rule_type': 'agg_dq', + 'tag': 'validity'}, + {'action_if_failed': 'ignore', + 'description': 'stddev of col3 value must be greater than one', + 'rule': 'stddev_col3_threshold', + 'rule_type': 'agg_dq', + 'tag': 'validity'}], # final_agg_result None, # source_query_dq_res None, # final_query_dq_res @@ -1450,48 +1381,6 @@ def fixture_create_stats_table(): {"row_dq_status": "Passed", "source_agg_dq_status": "Passed", "final_agg_dq_status": "Passed", "run_status": "Passed", "source_query_dq_status": "Skipped", "final_query_dq_status": "Skipped"} # status - - ), - ( - spark.createDataFrame( - [ - {"col1": 1, "col2": "a", "col3": 4}, - # row doesn't meet expectations(fail),log into error and raise error - {"col1": 2, "col2": "b", "col3": 5}, # row meet expectations(fail) - {"col1": 3, "col2": "c", "col3": 6}, # row meet expectations(fail) - ] - ), - { # expectations rules - "row_dq_rules": [], - "agg_dq_rules": [], - "target_table_name": "dq_spark.test_final_table" - - }, - True, - True, - True, - True, - True, - True, - False, # query_dq - False, # source_query_dq - False, # final_query_dq - SparkExpectationsMiscException, - 3, # input count - 0, # error count - 0, - None, - None, - None, # source_query_dq_res - None, # final_query_dq_res - {"rules": {"num_dq_rules": 0, "num_row_dq_rules": 0}, - "query_dq_rules": {"num_final_query_dq_rules": 0, "num_source_query_dq_rules": 0, - "num_query_dq_rules": 0}, - "agg_dq_rules": {"num_source_agg_dq_rules": 0, "num_agg_dq_rules": 0, - "num_final_agg_dq_rules": 0}}, # dq rules 
- {"row_dq_status": "Skipped", "source_agg_dq_status": "Failed", - "final_agg_dq_status": "Skipped", "run_status": "Failed", - "source_query_dq_status": "Skipped", "final_query_dq_status": "Skipped"} ), ( # Test case 16 @@ -1510,46 +1399,42 @@ def fixture_create_stats_table(): # row meets all row_dq_expectations ] ), - { # expectations rules - "query_dq_rules": [{ + [ + { "product_id": "product1", - "target_table_name": "dq_spark.test_table", + "table_name": "dq_spark.test_final_table", "rule_type": "query_dq", "rule": "sum_col1_threshold", "column_name": "col1", "expectation": "(select sum(col1) from test_table) > 10", + "action_if_failed": "ignore", + "tag": "validity", + "description": "sum of col1 value must be greater than 10", "enable_for_source_dq_validation": True, "enable_for_target_dq_validation": True, + "is_active": True, + "enable_error_drop_alert": False, + "error_drop_threshold": "20", + }, + { + "product_id": "product1", + "table_name": "dq_spark.test_final_table", + "rule_type": "query_dq", + "rule": "stddev_col3_threshold", + "column_name": "col3", + "expectation": "(select stddev(col3) from test_table) > 0", "action_if_failed": "ignore", "tag": "validity", - "description": "sum of col1 value must be greater than 10" - }, - { - "product_id": "product1", - "target_table_name": "dq_spark.test_table", - "rule_type": "query_dq", - "rule": "stddev_col3_threshold", - "column_name": "col3", - "expectation": "(select stddev(col3) from test_table) > 0", - "enable_for_source_dq_validation": True, - "enable_for_target_dq_validation": True, - "action_if_failed": "ignore", - "tag": "validity", - "description": "stddev of col3 value must be greater than 0" - } - ], - "target_table_name": "dq_spark.test_final_table" - - }, + "description": "stddev of col3 value must be greater than 0", + "enable_for_source_dq_validation": True, + "enable_for_target_dq_validation": True, + "is_active": True, + "enable_error_drop_alert": False, + "error_drop_threshold": "20", + } + ], False, # write to table False, # write to temp table - False, # row_dq - False, # agg_dq - False, # source_agg_dq - False, # final_agg_dq - True, # query_dq - True, # source_query_dq - False, # final_query_dq None, # expected result 3, # input count 0, # error count @@ -1570,7 +1455,6 @@ def fixture_create_stats_table(): {"row_dq_status": "Skipped", "source_agg_dq_status": "Skipped", "final_agg_dq_status": "Skipped", "run_status": "Passed", "source_query_dq_status": "Passed", "final_query_dq_status": "Skipped"} # status - ), ( @@ -1590,10 +1474,10 @@ def fixture_create_stats_table(): # row meets all row_dq_expectations ] ), - { # expectations rules - "row_dq_rules": [{ + [ + { "product_id": "product1", - "target_table_name": "dq_spark.test_table", + "table_name": "dq_spark.test_final_table", "rule_type": "row_dq", "rule": "col3_and_col1_threshold_4", "column_name": "col3, col1", @@ -1601,49 +1485,47 @@ def fixture_create_stats_table(): "action_if_failed": "drop", "tag": "validity", "description": "col3 and col1 operation value must be greater than 3", + "enable_for_source_dq_validation": False, + "enable_for_target_dq_validation": False, + "is_active": True, "enable_error_drop_alert": True, "error_drop_threshold": "20", - }], - "query_dq_rules": [{ + }, + { "product_id": "product1", - "target_table_name": "dq_spark.test_table", + "table_name": "dq_spark.test_final_table", "rule_type": "query_dq", "rule": "max_col1_threshold", "column_name": "col1", - "expectation": "(select max(col1) from target_test_table) > 10", - 
"enable_for_source_dq_validation": False, - "enable_for_target_dq_validation": True, + "expectation": "(select max(col1) from test_final_table_view) > 10", "action_if_failed": "ignore", "tag": "strict", "description": "max of col1 value must be greater than 10", - + "enable_for_source_dq_validation": False, + "enable_for_target_dq_validation": True, + "is_active": True, + "enable_error_drop_alert": False, + "error_drop_threshold": "20", }, - { - "product_id": "product1", - "target_table_name": "dq_spark.test_table", - "rule_type": "query_dq", - "rule": "min_col3_threshold", - "column_name": "col3", - "expectation": "(select min(col3) from target_test_table) > 0", - "enable_for_source_dq_validation": False, - "enable_for_target_dq_validation": True, - "action_if_failed": "ignore", - "tag": "validity", - "description": "min of col3 value must be greater than 0" - } - ], - "target_table_name": "dq_spark.test_final_table" - - }, + { + "product_id": "product1", + "table_name": "dq_spark.test_final_table", + "rule_type": "query_dq", + "rule": "min_col3_threshold", + "column_name": "col3", + "expectation": "(select min(col3) from test_final_table_view) > 0", + "action_if_failed": "ignore", + "tag": "validity", + "description": "min of col3 value must be greater than 0", + "enable_for_source_dq_validation": False, + "enable_for_target_dq_validation": True, + "is_active": True, + "enable_error_drop_alert": False, + "error_drop_threshold": "20", + } + ], True, # write to table False, # write to temp table - True, # row_dq - False, # agg_dq - False, # source_agg_dq - False, # final_agg_dq - True, # query_dq - False, # source_query_dq - True, # final_query_dq spark.createDataFrame( [ {"col1": 3, "col2": "c", "col3": 6}, @@ -1667,7 +1549,6 @@ def fixture_create_stats_table(): {"row_dq_status": "Passed", "source_agg_dq_status": "Skipped", "final_agg_dq_status": "Skipped", "run_status": "Passed", "source_query_dq_status": "Skipped", "final_query_dq_status": "Passed"} # status - ), ( # Test case 18 @@ -1686,57 +1567,58 @@ def fixture_create_stats_table(): # row meets all row_dq_expectations ] ), - { # expectations rules - "query_dq_rules": [{ + [ + { "product_id": "product1", - "target_table_name": "dq_spark.test_table", + "table_name": "dq_spark.test_final_table", "rule_type": "query_dq", "rule": "min_col1_threshold", "column_name": "col1", - "expectation": "(select min(col1) from test_table) > 10", - "enable_for_source_dq_validation": True, - "enable_for_target_dq_validation": False, + "expectation": "(select min(col1) from test_final_table_view) > 10", "action_if_failed": "fail", "tag": "validity", - "description": "min of col1 value must be greater than 10" + "description": "min of col1 value must be greater than 10", + "enable_for_source_dq_validation": True, + "enable_for_target_dq_validation": False, + "is_active": True, + "enable_error_drop_alert": False, + "error_drop_threshold": "20", }, - { - "product_id": "product1", - "target_table_name": "dq_spark.test_table", - "rule_type": "query_dq", - "rule": "stddev_col3_threshold", - "column_name": "col3", - "expectation": "(select stddev(col3) from test_table) > 0", - "enable_for_source_dq_validation": True, - "enable_for_target_dq_validation": False, - "action_if_failed": "ignore", - "tag": "validity", - "description": "stddev of col3 value must be greater than 0" - } - ], - "target_table_name": "dq_spark.test_final_table" - - }, + { + "product_id": "product1", + "table_name": "dq_spark.test_final_table", + "rule_type": "query_dq", + "rule": 
"stddev_col3_threshold", + "column_name": "col3", + "expectation": "(select stddev(col3) from test_final_table_view) > 0", + "action_if_failed": "ignore", + "tag": "validity", + "description": "stddev of col3 value must be greater than 0", + "enable_for_source_dq_validation": True, + "enable_for_target_dq_validation": False, + "is_active": True, + "enable_error_drop_alert": False, + "error_drop_threshold": "20", + } + ], False, # write to table False, # write to temp table - False, # row_dq - False, # agg_dq - False, # source_agg_dq - False, # final_agg_dq - True, # query_dq - True, # source_query_dq - False, # final_query_dq SparkExpectationsMiscException, # expected result 3, # input count 0, # error count 0, # output count None, # source_agg_result None, # final_agg_result - # final_agg_result - [{"description": "min of col1 value must be greater than 10", - "rule": "min_col1_threshold", - "rule_type": "query_dq", "action_if_failed": "fail", "tag": "validity"}], - # source_query_dq_res + [{'action_if_failed': 'fail', + 'description': 'min of col1 value must be greater than 10', + 'rule': 'min_col1_threshold', + 'rule_type': 'query_dq', + 'tag': 'validity'}, + {'action_if_failed': 'ignore', + 'description': 'stddev of col3 value must be greater than 0', + 'rule': 'stddev_col3_threshold', + 'rule_type': 'query_dq', + 'tag': 'validity'}], # source_query_dq_res None, # final_query_dq_res {"rules": {"num_dq_rules": 2, "num_row_dq_rules": 0}, "query_dq_rules": {"num_final_query_dq_rules": 2, "num_source_query_dq_rules": 2, @@ -1746,9 +1628,7 @@ def fixture_create_stats_table(): {"row_dq_status": "Skipped", "source_agg_dq_status": "Skipped", "final_agg_dq_status": "Skipped", "run_status": "Failed", "source_query_dq_status": "Failed", "final_query_dq_status": "Skipped"} # status - ), - ( # Test case 19 # In this test case, dq run set for query_dq final_query_dq(ignore, fail) @@ -1766,10 +1646,10 @@ def fixture_create_stats_table(): # row meets all row_dq_expectations ] ), - { # expectations rules - "row_dq_rules": [{ + [ + { "product_id": "product1", - "target_table_name": "dq_spark.test_table", + "table_name": "dq_spark.test_final_table", "rule_type": "row_dq", "rule": "col3_and_col1_threshold_4", "column_name": "col3, col1", @@ -1777,48 +1657,47 @@ def fixture_create_stats_table(): "action_if_failed": "drop", "tag": "validity", "description": "col3 and col1 operation value must be greater than 3", + "enable_for_source_dq_validation": False, + "enable_for_target_dq_validation": False, + "is_active": True, "enable_error_drop_alert": True, "error_drop_threshold": "25", - }], - "query_dq_rules": [{ + }, + { "product_id": "product1", - "target_table_name": "dq_spark.test_table", + "table_name": "dq_spark.test_final_table", "rule_type": "query_dq", "rule": "max_col1_threshold", "column_name": "col1", - "expectation": "(select max(col1) from target_test_table) > 10", - "enable_for_source_dq_validation": False, - "enable_for_target_dq_validation": True, + "expectation": "(select max(col1) from test_final_table_view) > 10", "action_if_failed": "fail", "tag": "strict", - "description": "max of col1 value must be greater than 10" + "description": "max of col1 value must be greater than 10", + "enable_for_source_dq_validation": False, + "enable_for_target_dq_validation": True, + "is_active": True, + "enable_error_drop_alert": False, + "error_drop_threshold": "20", }, - { - "product_id": "product1", - "target_table_name": "dq_spark.test_table", - "rule_type": "query_dq", - "rule": "min_col3_threshold", - 
"column_name": "col3", - "expectation": "(select min(col3) from target_test_table) > 0", - "enable_for_source_dq_validation": False, - "enable_for_target_dq_validation": True, - "action_if_failed": "ignore", - "tag": "validity", - "description": "min of col3 value must be greater than 0" - } - ], - "target_table_name": "dq_spark.test_final_table" - - }, + { + "product_id": "product1", + "table_name": "dq_spark.test_final_table", + "rule_type": "query_dq", + "rule": "min_col3_threshold", + "column_name": "col3", + "expectation": "(select min(col3) from test_final_table_view) > 0", + "action_if_failed": "ignore", + "tag": "validity", + "description": "min of col3 value must be greater than 0", + "enable_for_source_dq_validation": False, + "enable_for_target_dq_validation": True, + "is_active": True, + "enable_error_drop_alert": False, + "error_drop_threshold": "20", + } + ], True, # write to table False, # write to temp table - True, # row_dq - False, # agg_dq - False, # source_agg_dq - False, # final_agg_dq - True, # query_dq - False, # source_query_dq - True, # final_query_dq SparkExpectationsMiscException, # expected result 3, # input count 2, # error count @@ -1839,11 +1718,9 @@ def fixture_create_stats_table(): {"row_dq_status": "Passed", "source_agg_dq_status": "Skipped", "final_agg_dq_status": "Skipped", "run_status": "Failed", "source_query_dq_status": "Skipped", "final_query_dq_status": "Failed"} # status - ), - ( - # Test case 19 + # Test case 20 # In this test case, dq run set for query_dq source_query_dq & # final_query_dq(ignore, fail) # with action_if_failed (ignore, fail) for query_dq @@ -1861,10 +1738,10 @@ def fixture_create_stats_table(): # row meets all row_dq_expectations ] ), - { # expectations rules - "row_dq_rules": [{ + [ + { "product_id": "product1", - "target_table_name": "dq_spark.test_table", + "table_name": "dq_spark.test_final_table", "rule_type": "row_dq", "rule": "col3_mod_2", "column_name": "col3", @@ -1872,61 +1749,63 @@ def fixture_create_stats_table(): "action_if_failed": "drop", "tag": "validity", "description": "col3 mod must equals to 0", + "enable_for_source_dq_validation": False, + "enable_for_target_dq_validation": False, + "is_active": True, "enable_error_drop_alert": True, "error_drop_threshold": "40", - }], - "query_dq_rules": [{ + }, + { "product_id": "product1", - "target_table_name": "dq_spark.test_table", + "table_name": "dq_spark.test_final_table", "rule_type": "query_dq", "rule": "min_col1_threshold", "column_name": "col1", - "expectation": "(select min(col1) from test_table) > 10", + "expectation": "(select min(col1) from test_final_table_view) > 10", + "action_if_failed": "ignore", + "tag": "strict", + "description": "min of col1 value must be greater than 10", "enable_for_source_dq_validation": True, "enable_for_target_dq_validation": False, + "is_active": True, + "enable_error_drop_alert": False, + "error_drop_threshold": "20", + }, + { + "product_id": "product1", + "table_name": "dq_spark.test_final_table", + "rule_type": "query_dq", + "rule": "max_col1_threshold", + "column_name": "col1", + "expectation": "(select max(col1) from test_final_table_view) > 100", "action_if_failed": "ignore", "tag": "strict", - "description": "min of col1 value must be greater than 10" + "description": "max of col1 value must be greater than 100", + "enable_for_source_dq_validation": False, + "enable_for_target_dq_validation": True, + "is_active": True, + "enable_error_drop_alert": False, + "error_drop_threshold": "20", }, - { - "product_id": "product1", - 
"target_table_name": "dq_spark.test_table", - "rule_type": "query_dq", - "rule": "max_col1_threshold", - "column_name": "col1", - "expectation": "(select max(col1) from target_test_table) > 100", - "enable_for_source_dq_validation": False, - "enable_for_target_dq_validation": True, - "action_if_failed": "ignore", - "tag": "strict", - "description": "max of col1 value must be greater than 100" - }, - { - "product_id": "product1", - "target_table_name": "dq_spark.test_table", - "rule_type": "query_dq", - "rule": "min_col3_threshold", - "column_name": "col3", - "expectation": "(select min(col3) from target_test_table) > 0", - "enable_for_source_dq_validation": False, - "enable_for_target_dq_validation": True, - "action_if_failed": "fail", - "tag": "validity", - "description": "min of col3 value must be greater than 0" - } - ], - "target_table_name": "dq_spark.test_final_table" - - }, + { + "product_id": "product1", + "table_name": "dq_spark.test_final_table", + "rule_type": "query_dq", + "rule": "min_col3_threshold", + "column_name": "col3", + "expectation": "(select min(col3) from test_final_table_view) > 0", + "action_if_failed": "fail", + "tag": "validity", + "description": "min of col3 value must be greater than 0", + "enable_for_source_dq_validation": False, + "enable_for_target_dq_validation": True, + "is_active": True, + "enable_error_drop_alert": False, + "error_drop_threshold": "20", + } + ], True, # write to table False, # write to temp table - True, # row_dq - False, # agg_dq - False, # source_agg_dq - False, # final_agg_dq - True, # query_dq - True, # source_query_dq - True, # final_query_dq spark.createDataFrame( [ {"col1": 1, "col2": "a", "col3": 4}, @@ -1955,11 +1834,9 @@ def fixture_create_stats_table(): {"row_dq_status": "Passed", "source_agg_dq_status": "Skipped", "final_agg_dq_status": "Skipped", "run_status": "Passed", "source_query_dq_status": "Passed", "final_query_dq_status": "Passed"} # status - ), - ( - # Test case 20 + # Test case 21 # In this test case, dq run set for query_dq source_query_dq & # final_query_dq(ignore, fail) # with action_if_failed (ignore, fail) for query_dq @@ -1977,10 +1854,10 @@ def fixture_create_stats_table(): # row meets all row_dq_expectations ] ), - { # expectations rules - "row_dq_rules": [{ + [ + { "product_id": "product1", - "target_table_name": "dq_spark.test_table", + "table_name": "dq_spark.test_final_table", "rule_type": "row_dq", "rule": "col3_mod_2", "column_name": "col3", @@ -1988,61 +1865,63 @@ def fixture_create_stats_table(): "action_if_failed": "drop", "tag": "validity", "description": "col3 mod must equals to 0", + "enable_for_source_dq_validation": False, + "enable_for_target_dq_validation": False, + "is_active": True, "enable_error_drop_alert": False, "error_drop_threshold": "10", - }], - "query_dq_rules": [{ + }, + { "product_id": "product1", - "target_table_name": "dq_spark.test_table", + "table_name": "dq_spark.test_final_table", "rule_type": "query_dq", "rule": "min_col1_threshold", "column_name": "col1", - "expectation": "(select min(col1) from test_table) > 10", + "expectation": "(select min(col1) from test_final_table_view) > 10", + "action_if_failed": "ignore", + "tag": "strict", + "description": "min of col1 value must be greater than 10", "enable_for_source_dq_validation": True, "enable_for_target_dq_validation": False, - "action_if_failed": "ignore", + "is_active": True, + "enable_error_drop_alert": False, + "error_drop_threshold": "20", + }, + { + "product_id": "product1", + "table_name": 
"dq_spark.test_final_table", + "rule_type": "query_dq", + "rule": "max_col1_threshold", + "column_name": "col1", + "expectation": "(select max(col1) from test_final_table_view) > 100", + "action_if_failed": "fail", "tag": "strict", - "description": "min of col1 value must be greater than 10" + "description": "max of col1 value must be greater than 100", + "enable_for_source_dq_validation": False, + "enable_for_target_dq_validation": True, + "is_active": True, + "enable_error_drop_alert": False, + "error_drop_threshold": "20", }, - { - "product_id": "product1", - "target_table_name": "dq_spark.test_table", - "rule_type": "query_dq", - "rule": "max_col1_threshold", - "column_name": "col1", - "expectation": "(select max(col1) from target_test_table) > 100", - "enable_for_source_dq_validation": False, - "enable_for_target_dq_validation": True, - "action_if_failed": "fail", - "tag": "strict", - "description": "max of col1 value must be greater than 100" - }, - { - "product_id": "product1", - "target_table_name": "dq_spark.test_table", - "rule_type": "query_dq", - "rule": "min_col3_threshold", - "column_name": "col3", - "expectation": "(select min(col3) from target_test_table) > 0", - "enable_for_source_dq_validation": False, - "enable_for_target_dq_validation": True, - "action_if_failed": "ignore", - "tag": "validity", - "description": "min of col3 value must be greater than 0" - } - ], - "target_table_name": "dq_spark.test_final_table" - - }, + { + "product_id": "product1", + "table_name": "dq_spark.test_final_table", + "rule_type": "query_dq", + "rule": "min_col3_threshold", + "column_name": "col3", + "expectation": "(select min(col3) from test_final_table_view) > 0", + "action_if_failed": "ignore", + "tag": "validity", + "description": "min of col3 value must be greater than 0", + "enable_for_source_dq_validation": False, + "enable_for_target_dq_validation": True, + "is_active": True, + "enable_error_drop_alert": False, + "error_drop_threshold": "20", + } + ], True, # write to table False, # write to temp table - True, # row_dq - False, # agg_dq - False, # source_agg_dq - False, # final_agg_dq - True, # query_dq - True, # source_query_dq - True, # final_query_dq SparkExpectationsMiscException, # expected result 3, # input count 1, # error count @@ -2066,10 +1945,9 @@ def fixture_create_stats_table(): {"row_dq_status": "Passed", "source_agg_dq_status": "Skipped", "final_agg_dq_status": "Skipped", "run_status": "Failed", "source_query_dq_status": "Passed", "final_query_dq_status": "Failed"} # status - ), ( - # Test case 20 + # Test case 22 # In this test case, dq run set for query_dq source_query_dq & # final_query_dq(ignore, fail) # with action_if_failed (ignore, fail) for query_dq @@ -2087,23 +1965,26 @@ def fixture_create_stats_table(): # row meets all row_dq_expectations ] ), - { # expectations rules - "agg_dq_rules": [{ + [ + { "product_id": "product1", - "target_table_name": "dq_spark.test_table", + "table_name": "dq_spark.test_final_table", "rule_type": "agg_dq", "rule": "col3_max_value", "column_name": "col3", "expectation": "max(col3) > 1", - "enable_for_source_dq_validation": True, - "enable_for_target_dq_validation": True, "action_if_failed": "fail", "tag": "validity", - "description": "col3 mod must equals to 0" - }], - "row_dq_rules": [{ + "description": "col3 mod must equals to 0", + "enable_for_source_dq_validation": True, + "enable_for_target_dq_validation": True, + "is_active": True, + "enable_error_drop_alert": False, + "error_drop_threshold": "20", + }, + { "product_id": 
"product1", - "target_table_name": "dq_spark.test_table", + "table_name": "dq_spark.test_final_table", "rule_type": "row_dq", "rule": "col3_mod_2", "column_name": "col3", @@ -2111,48 +1992,48 @@ def fixture_create_stats_table(): "action_if_failed": "drop", "tag": "validity", "description": "col3 mod must equals to 0", + "enable_for_source_dq_validation": False, + "enable_for_target_dq_validation": False, + "is_active": True, "enable_error_drop_alert": False, "error_drop_threshold": "100", - }], - "query_dq_rules": [{ + }, + { "product_id": "product1", - "target_table_name": "dq_spark.test_table", + "table_name": "dq_spark.test_final_table", "rule_type": "query_dq", "rule": "count_col1_threshold", "column_name": "col1", - "expectation": "(select count(col1) from test_table) > 3", + "expectation": "(select count(col1) from test_final_table_view) > 3", + "action_if_failed": "ignore", + "tag": "strict", + "description": "count of col1 value must be greater than 3", "enable_for_source_dq_validation": True, "enable_for_target_dq_validation": False, + "is_active": True, + "enable_error_drop_alert": False, + "error_drop_threshold": "20", + }, + { + "product_id": "product1", + "table_name": "dq_spark.test_final_table", + "rule_type": "query_dq", + "rule": "col3_positive_threshold", + "column_name": "col1", + "expectation": "(select count(case when col3>0 then 1 else 0 end) from " + "test_final_table_view) > 10", "action_if_failed": "ignore", "tag": "strict", - "description": "count of col1 value must be greater than 3" - }, - { - "product_id": "product1", - "target_table_name": "dq_spark.test_table", - "rule_type": "query_dq", - "rule": "col3_positive_threshold", - "column_name": "col1", - "expectation": "(select count(case when col3>0 then 1 else 0 end) from target_test_table) > 10", - "enable_for_source_dq_validation": False, - "enable_for_target_dq_validation": True, - "action_if_failed": "ignore", - "tag": "strict", - "description": "count of col3 positive value must be greater than 10" - } - ], - "target_table_name": "dq_spark.test_final_table" - - }, + "description": "count of col3 positive value must be greater than 10", + "enable_for_source_dq_validation": False, + "enable_for_target_dq_validation": True, + "is_active": True, + "enable_error_drop_alert": False, + "error_drop_threshold": "20", + } + ], True, # write to table False, # write to temp table - True, # row_dq - True, # agg_dq - True, # source_agg_dq - True, # final_agg_dq - True, # query_dq - True, # source_query_dq - True, # final_query_dq spark.createDataFrame( [ {"col1": 1, "col2": "a", "col3": 4}, @@ -2187,13 +2068,6 @@ def test_with_expectations(input_df, expectations, write_to_table, write_to_temp_table, - row_dq, - agg_dq, - source_agg_dq, - final_agg_dq, - query_dq, - source_query_dq, - final_query_dq, expected_output, input_count, error_count, @@ -2204,42 +2078,36 @@ def test_with_expectations(input_df, final_query_dq_res, dq_rules, status, - _fixture_spark_expectations, - _fixture_context, - _fixture_create_stats_table, + _fixture_create_database, _fixture_local_kafka_topic): - spark.conf.set("spark.sql.session.timeZone", "Etc/UTC") - spark_conf = {"spark.sql.session.timeZone": "Etc/UTC"} - options = {'mode': 'overwrite', "format": "delta"} - options_error_table = {'mode': 'overwrite', "format": "delta"} - input_df.createOrReplaceTempView("test_table") - _fixture_context._num_row_dq_rules = (dq_rules.get("rules").get("num_row_dq_rules")) - _fixture_context._num_dq_rules = (dq_rules.get("rules").get("num_dq_rules")) - 
_fixture_context._num_agg_dq_rules = (dq_rules.get("agg_dq_rules")) - _fixture_context._num_query_dq_rules = (dq_rules.get("query_dq_rules")) + spark.conf.set("spark.sql.session.timeZone", "Etc/UTC") + + rules_df = spark.createDataFrame(expectations) if len(expectations) > 0 else expectations + rules_df.show(truncate=False) if len(expectations) > 0 else None + + writer = WrappedDataFrameWriter().mode("append").format("parquet") + se = SparkExpectations(product_id="product1", + rules_df=rules_df, + stats_table="dq_spark.test_dq_stats_table", + stats_table_writer=writer, + target_and_error_table_writer=writer, + debugger=False, + ) + se._context._run_date = "2022-12-27 10:00:00" + se._context._env = "local" + se._context.set_input_count(100) + se._context.set_output_count(100) + se._context.set_error_count(0) + se._context._run_id = "product1_run_test" # Decorate the mock function with required args - @_fixture_spark_expectations.with_expectations( - expectations, - write_to_table, - write_to_temp_table, - row_dq, - agg_dq={ - user_config.se_agg_dq: agg_dq, - user_config.se_source_agg_dq: source_agg_dq, - user_config.se_final_agg_dq: final_agg_dq, - }, - query_dq={ - user_config.se_query_dq: query_dq, - user_config.se_source_query_dq: source_query_dq, - user_config.se_final_query_dq: final_query_dq, - user_config.se_target_table_view: "target_test_table" - }, - spark_conf={**spark_conf, **{user_config.se_notifications_on_fail: False}}, - options=options, - options_error_table=options_error_table, + @se.with_expectations( + "dq_spark.test_final_table", + user_conf={user_config.se_notifications_on_fail: False}, + write_to_table=write_to_table, + write_to_temp_table=write_to_temp_table ) def get_dataset() -> DataFrame: return input_df @@ -2260,16 +2128,16 @@ def get_dataset() -> DataFrame: else: get_dataset() # decorated_func() - if row_dq is True and write_to_table is True: + if write_to_table is True: expected_output_df = expected_output.withColumn("run_id", lit("product1_run_test")) \ - .withColumn("run_date", to_timestamp(lit("2022-12-27 10:39:44"))) + .withColumn("run_date", to_timestamp(lit("2022-12-27 10:00:00"))) - error_table = spark.table("dq_spark.test_final_table_error") result_df = spark.table("dq_spark.test_final_table") - result_df.show(truncate=False) - assert result_df.orderBy("col2").collect() == expected_output_df.orderBy("col2").collect() - assert error_table.count() == error_count + + if spark.catalog.tableExists("dq_spark.test_final_table_error"): + error_table = spark.table("dq_spark.test_final_table_error") + assert error_table.count() == error_count stats_table = spark.table("test_dq_stats_table") row = stats_table.first() @@ -2289,6 +2157,10 @@ def get_dataset() -> DataFrame: assert row.dq_status.get("source_query_dq") == status.get("source_query_dq_status") assert row.dq_status.get("final_query_dq") == status.get("final_query_dq_status") assert row.dq_status.get("run_status") == status.get("run_status") + assert row.meta_dq_run_id == "product1_run_test" + assert row.meta_dq_run_date == datetime.date(2022, 12, 27) + assert row.meta_dq_run_datetime == datetime.datetime(2022, 12, 27, 10, 00, 00) + assert len(stats_table.columns) == 20 assert spark.read.format("kafka").option( "kafka.bootstrap.servers", "localhost:9092" @@ -2299,70 +2171,27 @@ def get_dataset() -> DataFrame: ).load().orderBy(col('timestamp').desc()).limit(1).selectExpr( "cast(value as string) as value").collect() == stats_table.selectExpr("to_json(struct(*)) AS value").collect() - spark.sql("drop table if 
exists test_final_table_error") - os.system("rm -rf /tmp/hive/warehouse/dq_spark.db/test_final_table_error") - - -# @pytest.mark.parametrize("write_to_table", [(True), (False)]) -# def test_with_table_write_expectations( -# write_to_table, -# _fixture_create_database, -# _fixture_df, -# _fixture_expectations, -# _fixture_context, -# _fixture_dq_rules, -# _fixture_spark_expectations, -# _fixture_create_stats_table, -# _fixture_local_kafka_topic): -# _fixture_context._num_row_dq_rules = (_fixture_dq_rules.get("rules").get("num_row_dq_rules")) -# _fixture_context._num_dq_rules = (_fixture_dq_rules.get("rules").get("num_dq_rules")) -# _fixture_context._num_agg_dq_rules = (_fixture_dq_rules.get("agg_dq_rules")) -# _fixture_context._num_query_dq_rules = (_fixture_dq_rules.get("query_dq_rules")) -# -# # Create a mock object with a return value -# mock_func = Mock(return_value=_fixture_df) -# -# # Decorate the mock function with required args -# decorated_func = _fixture_spark_expectations.with_expectations( -# _fixture_expectations, -# write_to_table, -# agg_dq=None, -# query_dq=None, -# spark_conf={UserConfig.se_notifications_on_fail: False}, -# options={'mode': 'overwrite', "format": "delta"}, -# options_error_table={'mode': 'overwrite', "format": "delta"} -# )(mock_func) -# -# decorated_func() -# -# if write_to_table is True: -# assert "test_final_table" in [obj.name for obj in spark.catalog.listTables()] -# spark.sql("drop table if exists dq_spark.test_final_table") -# else: -# assert "test_final_table" not in [obj.name for obj in spark.catalog.listTables()] + # spark.sql("select * from dq_spark.test_final_table").show(truncate=False) + # spark.sql("select * from dq_spark.test_final_table_error").show(truncate=False) + # spark.sql("select * from dq_spark.test_dq_stats_table").show(truncate=False) + + for db in spark.catalog.listDatabases(): + if db.name != "default": + spark.sql(f"DROP DATABASE {db.name} CASCADE") + spark.sql("CLEAR CACHE") + + # os.system("rm -rf /tmp/hive/warehouse/dq_spark.db/test_final_table_error") @patch("spark_expectations.core.expectations.SparkExpectationsWriter.write_error_stats") def test_with_expectations_patch(_write_error_stats, _fixture_create_database, _fixture_spark_expectations, - _fixture_context, - _fixture_dq_rules, _fixture_df, - _fixture_expectations): - _fixture_context._num_row_dq_rules = (_fixture_dq_rules.get("rules").get("num_row_dq_rules")) - _fixture_context._num_dq_rules = (_fixture_dq_rules.get("rules").get("num_dq_rules")) - _fixture_context._num_agg_dq_rules = (_fixture_dq_rules.get("agg_dq_rules")) - _fixture_context._num_query_dq_rules = (_fixture_dq_rules.get("query_dq_rules")) - + _fixture_rules_df): decorated_func = _fixture_spark_expectations.with_expectations( - _fixture_expectations, - True, - agg_dq=None, - query_dq=None, - spark_conf={user_config.se_notifications_on_fail: False}, - options={'mode': 'overwrite', "format": "delta"}, - options_error_table={'mode': 'overwrite', "format": "delta"} + "dq_spark.test_final_table", + user_conf={user_config.se_notifications_on_fail: False}, )(Mock(return_value=_fixture_df)) decorated_func() @@ -2370,14 +2199,29 @@ def test_with_expectations_patch(_write_error_stats, _write_error_stats.assert_called_once_with() +def test_with_expectations_overwrite_writers( + _fixture_create_database, + _fixture_spark_expectations, + _fixture_df, + _fixture_rules_df): + modified_writer = WrappedDataFrameWriter().mode("overwrite").format("iceberg") + _fixture_spark_expectations.with_expectations( + 
"dq_spark.test_final_table", + user_conf={user_config.se_notifications_on_fail: False}, + target_and_error_table_writer=modified_writer + )(Mock(return_value=_fixture_df)) + + assert _fixture_spark_expectations._context.get_target_and_error_table_writer_config == modified_writer.build() + + def test_with_expectations_dataframe_not_returned_exception(_fixture_create_database, _fixture_spark_expectations, _fixture_df, - _fixture_expectations, + _fixture_rules_df, _fixture_local_kafka_topic): partial_func = _fixture_spark_expectations.with_expectations( - _fixture_expectations, - spark_conf={user_config.se_notifications_on_fail: False}, + "dq_spark.test_final_table", + user_conf={user_config.se_notifications_on_fail: False}, ) with pytest.raises(SparkExpectationsMiscException, @@ -2390,17 +2234,43 @@ def test_with_expectations_dataframe_not_returned_exception(_fixture_create_data # Decorate the mock function with required args decorated_func = partial_func(mock_func) decorated_func() + for db in spark.catalog.listDatabases(): + if db.name != "default": + spark.sql(f"DROP DATABASE {db.name} CASCADE") + spark.sql("CLEAR CACHE") def test_with_expectations_exception(_fixture_create_database, _fixture_spark_expectations, - _fixture_df, - _fixture_expectations, - _fixture_create_stats_table, _fixture_local_kafka_topic): - partial_func = _fixture_spark_expectations.with_expectations( - _fixture_expectations, - spark_conf={user_config.se_notifications_on_fail: False} + rules_dict = { + "product_id": "product1", + "table_name": "dq_spark.test_table", + "rule_type": "row_dq", + "rule": "col1_threshold", + "column_name": "col1", + "expectation": "col1 > 1", + "action_if_failed": "ignore", + "tag": "validity", + "description": "col1 value must be greater than 1", + "enable_for_source_dq_validation": True, + "enable_for_target_dq_validation": True, + "is_active": True, + "enable_error_drop_alert": True, + "error_drop_threshold": "10" + } + rules_df = spark.createDataFrame([rules_dict]) + writer = WrappedDataFrameWriter().mode("append").format("delta") + se = SparkExpectations(product_id="product1", + rules_df=rules_df, + stats_table="dq_spark.test_dq_stats_table", + stats_table_writer=writer, + target_and_error_table_writer=writer, + debugger=False, + ) + partial_func = se.with_expectations( + "dq_spark.test_final_table", + user_conf={user_config.se_notifications_on_fail: False} ) with pytest.raises(SparkExpectationsMiscException, @@ -2412,267 +2282,269 @@ def test_with_expectations_exception(_fixture_create_database, decorated_func = partial_func(mock_func) decorated_func() + for db in spark.catalog.listDatabases(): + if db.name != "default": + spark.sql(f"DROP DATABASE {db.name} CASCADE") + spark.sql("CLEAR CACHE") + + +# @patch('spark_expectations.core.expectations.SparkExpectationsNotify', autospec=True, +# spec_set=True) +# @patch('spark_expectations.notifications.push.spark_expectations_notify._notification_hook', autospec=True, +# spec_set=True) +def test_error_threshold_breach( + # _mock_notification_hook, _mock_spark_expectations_notify, + _fixture_create_database, + _fixture_local_kafka_topic +): + input_df = spark.createDataFrame( + [ + {"col1": 1, "col2": "a", "col3": 4}, + {"col1": 2, "col2": "b", "col3": 5}, + {"col1": 3, "col2": "c", 'col3': 6}, + ] + ) -@pytest.mark.parametrize("input_df, " - "expectations, " - "write_to_table, " - "write_to_temp_table, " - "row_dq, agg_dq, " - "source_agg_dq, " - "final_agg_dq," - "dq_rules ", - [ - ( - # In this test case, the action for failed rows is 
"ignore" & "drop", - # collect stats in the test_stats_table and - # log the error records into the error table. - spark.createDataFrame( - [ - {"col1": 1, "col2": "a", "col3": 4}, - # row doesn't meet expectations1(ignore) 2(drop), log into err & fnl - {"col1": 2, "col2": "b", "col3": 5}, - # row meets expectations1(ignore), log into final table - {"col1": 3, "col2": "c", 'col3': 6}, - # row doesnt'meets expectations1(ignore), log into final table - ] - ), - { # expectations rules - "row_dq_rules": [{ - "product_id": "product1", - "target_table_name": "dq_spark.test_table", - "rule_type": "row_dq", - "rule": "col3_threshold", - "column_name": "col3", - "expectation": "col3 > 6", - "action_if_failed": "ignore", - "tag": "strict", - "description": "col3 value must be greater than 6", - "enable_error_drop_alert": True, - "error_drop_threshold": "25", - }, - { - "product_id": "product1", - "target_table_name": "dq_spark.test_table", - "rule_type": "row_dq", - "rule": "col1_add_col3_threshold", - "column_name": "col1", - "expectation": "(col1+col3) > 6", - "action_if_failed": "drop", - "tag": "strict", - "description": "col1_add_col3 value must be greater than 6", - "enable_error_drop_alert": True, - "error_drop_threshold": "50", - } - ], - "agg_dq_rules": [{}], - "target_table_name": "dq_spark.test_final_table" - - }, - True, # write to table - True, # write to temp table - True, # row_dq - False, # agg_dq - False, # source_agg_dq - False, # final_agg_dq_res - {"rules": {"num_dq_rules": 0, "num_row_dq_rules": 0}, - "query_dq_rules": {"num_final_query_dq_rules": 0, "num_source_query_dq_rules": 0, - "num_query_dq_rules": 0}, - "agg_dq_rules": {"num_source_agg_dq_rules": 0, "num_agg_dq_rules": 0, - "num_final_agg_dq_rules": 0}} # dq_rules - ) - ]) -@patch('spark_expectations.core.expectations.SparkExpectationsNotify', autospec=True, - spec_set=True) -@patch('spark_expectations.notifications.push.spark_expectations_notify._notification_hook', autospec=True, - spec_set=True) -def test_error_threshold_breach(_mock_notification_hook, _mock_spark_expectations_notify, input_df, - expectations, - write_to_table, - write_to_temp_table, - row_dq, - agg_dq, - source_agg_dq, - final_agg_dq, - dq_rules, - _fixture_spark_expectations, - _fixture_context, - _fixture_create_stats_table, - _fixture_local_kafka_topic): - spark.conf.set("spark.sql.session.timeZone", "Etc/UTC") - spark_conf = {"spark.sql.session.timeZone": "Etc/UTC"} - options = {'mode': 'overwrite', "format": "delta"} - options_error_table = {'mode': 'overwrite', "format": "delta"} - - # set neccessary parameters in context class or object - _fixture_context._num_row_dq_rules = (dq_rules.get("rules").get("num_row_dq_rules")) - _fixture_context._num_dq_rules = (dq_rules.get("rules").get("num_dq_rules")) - _fixture_context._num_agg_dq_rules = (dq_rules.get("agg_dq_rules")) - _fixture_context._num_query_dq_rules = (dq_rules.get("query_dq_rules")) - - # Decorate the mock function with required args - @_fixture_spark_expectations.with_expectations( - expectations, - write_to_table, - write_to_temp_table, - row_dq, - agg_dq={ - user_config.se_agg_dq: agg_dq, - user_config.se_source_agg_dq: source_agg_dq, - user_config.se_final_agg_dq: final_agg_dq, + rules = [ + { + "product_id": "product1", + "table_name": "dq_spark.test_final_table", + "rule_type": "row_dq", + "rule": "col1_add_col3_threshold", + "column_name": "col1", + "expectation": "(col1+col3) > 6", + "action_if_failed": "drop", + "tag": "strict", + "description": "col1_add_col3 value must be 
greater than 6", + "enable_for_source_dq_validation": True, + "enable_for_target_dq_validation": True, + "is_active": True, + "enable_error_drop_alert": True, + "error_drop_threshold": "25" }, - query_dq=None, - spark_conf={ - user_config.se_notifications_on_fail: False, + { + "product_id": "product1", + "table_name": "dq_spark.test_final_table", + "rule_type": "query_dq", + "rule": "col3_positive_threshold", + "column_name": "col3", + "expectation": "(select count(case when col3>0 then 1 else 0 end) from test_final_table_view) > 10", + "action_if_failed": "ignore", + "tag": "strict", + "description": "count of col3 positive value must be greater than 10", + "enable_for_source_dq_validation": False, + "enable_for_target_dq_validation": True, + "is_active": True, + "enable_error_drop_alert": True, + "error_drop_threshold": "10" + }] + + # create a PySpark DataFrame from the list of dictionaries + rules_df = spark.createDataFrame(rules) + + writer = WrappedDataFrameWriter().mode("append").format("delta") + + with patch( + 'spark_expectations.notifications.push.spark_expectations_notify.SparkExpectationsNotify' + '.notify_on_exceeds_of_error_threshold', + autospec=True, spec_set=True) as _mock_notification_hook: + se = SparkExpectations(product_id="product1", + rules_df=rules_df, + stats_table="dq_spark.test_dq_stats_table", + stats_table_writer=writer, + target_and_error_table_writer=writer, + debugger=False, + ) + + from spark_expectations.config.user_config import Constants as user_config + conf = { + user_config.se_notifications_on_fail: True, user_config.se_notifications_on_error_drop_exceeds_threshold_breach: True, - user_config.se_notifications_on_error_drop_threshold: 10, - }, - options=options, - options_error_table=options_error_table - ) - def get_dataset() -> DataFrame: - return input_df + user_config.se_notifications_on_error_drop_threshold: 15, + } - get_dataset() + @se.with_expectations( + target_table="dq_spark.test_final_table", + write_to_table=True, + user_conf=conf + ) + def get_dataset() -> DataFrame: + return input_df - _mock_notification_hook.send_notification.assert_called_once() - # _mock_spark_expectations_notify("product_id", - # _fixture_context).notify_on_exceeds_of_error_threshold.assert_called_once() + get_dataset() + _mock_notification_hook.assert_called_once() + for db in spark.catalog.listDatabases(): + if db.name != "default": + spark.sql(f"DROP DATABASE {db.name} CASCADE") + spark.sql("CLEAR CACHE") -@pytest.mark.parametrize("input_df, " - "expectations, " - "write_to_table, " - "write_to_temp_table, " - "row_dq, agg_dq, " - "source_agg_dq, " - "final_agg_dq, " - "query_dq, " - "source_query_dq, " - "final_query_dq, " - "expected_output, " - "input_count, " - "error_count, " - "output_count ", - [( - # In this test case, dq run set for query_dq source_query_dq & - # final_query_dq(ignore, fail) - # with action_if_failed (ignore, fail) for query_dq - # collect stats in the test_stats_table, error into error_table & raise the error - spark.createDataFrame( - [ - # min of col1 must be greater than 10(ignore) - source_query_dq - # max of col1 must be greater than 100(fail) - final_query_dq - # min of col3 must be greater than 0(fail) - final_query_dq - {"col1": 1, "col2": "a", "col3": 4}, - # row meets all row_dq_expectations(drop) - {"col1": 2, "col2": "b", "col3": 5}, - # ow doesn't meet row_dq expectation1(drop) - {"col1": 3, "col2": "c", "col3": 6}, - # row meets all row_dq_expectations - ] - ), - { # expectations rules - "row_dq_rules": [{ - "product_id": 
"product1", - "target_table_name": "dq_spark.test_table", - "rule_type": "row_dq", - "rule": "col3_mod_2", - "column_name": "col3", - "expectation": "(col3 % 2) = 0", - "action_if_failed": "drop", - "tag": "validity", - "description": "col3 mod must equals to 0", - "enable_error_drop_alert": False, - "error_drop_threshold": "10", - }], - "query_dq_rules": [ - { - "product_id": "product1", - "target_table_name": "dq_spark.test_table", - "rule_type": "query_dq", - "rule": "col3_positive_threshold", - "column_name": "col1", - "expectation": "(select count(case when col3>0 then 1 else 0 end) from target_test_table) > 10", - "enable_for_source_dq_validation": False, - "enable_for_target_dq_validation": True, - "action_if_failed": "ignore", - "tag": "strict", - "description": "count of col3 positive value must be greater than 10" - } - ], - "target_table_name": "dq_spark.test_final_table" - - }, - True, # write to table - False, # write to temp table - True, # row_dq - False, # agg_dq - False, # source_agg_dq - False, # final_agg_dq - True, # query_dq - False, # source_query_dq - True, # final_query_dq - spark.createDataFrame( - [ - {"col1": 1, "col2": "a", "col3": 4}, - {"col1": 3, "col2": "c", "col3": 6}, - ] - ), # expected result - 3, # input count - 1, # error count - 2, # output count - )]) -def test_target_table_view_exception(input_df, - expectations, - write_to_table, - write_to_temp_table, - row_dq, - agg_dq, - source_agg_dq, - final_agg_dq, - query_dq, - source_query_dq, - final_query_dq, - expected_output, - input_count, - error_count, - output_count, - _fixture_spark_expectations, - _fixture_context, - _fixture_create_stats_table, + +def test_target_table_view_exception(_fixture_create_database, _fixture_local_kafka_topic): - spark.conf.set("spark.sql.session.timeZone", "Etc/UTC") - spark_conf = {"spark.sql.session.timeZone": "Etc/UTC"} - options = {'mode': 'overwrite', "format": "delta"} - options_error_table = {'mode': 'overwrite', "format": "delta"} + rules = [ + { + "product_id": "product1", + "table_name": "dq_spark.test_final_table", + "rule_type": "row_dq", + "rule": "col1_threshold", + "column_name": "col1", + "expectation": "col1 > 1", + "action_if_failed": "ignore", + "tag": "validity", + "description": "col1 value must be greater than 1", + "enable_for_source_dq_validation": True, + "enable_for_target_dq_validation": True, + "is_active": True, + "enable_error_drop_alert": True, + "error_drop_threshold": "10" + }, + { + "product_id": "product1", + "table_name": "dq_spark.test_final_table", + "rule_type": "query_dq", + "rule": "col3_positive_threshold", + "column_name": "col3", + "expectation": "(select count(case when col3>0 then 1 else 0 end) from target_test_table) > 10", + "action_if_failed": "ignore", + "tag": "strict", + "description": "count of col3 positive value must be greater than 10", + "enable_for_source_dq_validation": False, + "enable_for_target_dq_validation": True, + "is_active": True, + "enable_error_drop_alert": True, + "error_drop_threshold": "10" + }] - input_df.createOrReplaceTempView("test_table") + input_df = spark.createDataFrame( + [ + # min of col1 must be greater than 10(ignore) - source_query_dq + # max of col1 must be greater than 100(fail) - final_query_dq + # min of col3 must be greater than 0(fail) - final_query_dq + {"col1": 1, "col2": "a", "col3": 4}, + # row meets all row_dq_expectations(drop) + {"col1": 2, "col2": "b", "col3": 5}, + # ow doesn't meet row_dq expectation1(drop) + {"col1": 3, "col2": "c", "col3": 6}, + # row meets all 
row_dq_expectations + ] + ) + + # create a PySpark DataFrame from the list of dictionaries + rules_df = spark.createDataFrame(rules) + rules_df.createOrReplaceTempView("test_table") + + writer = WrappedDataFrameWriter().mode("append").format("delta") + se = SparkExpectations(product_id="product1", + rules_df=rules_df, + stats_table="dq_spark.test_dq_stats_table", + stats_table_writer=writer, + target_and_error_table_writer=writer, + debugger=False, + ) + + @se.with_expectations( + target_table="dq_spark.test_final_table", + write_to_table=True, + target_table_view="test_table", - # Decorate the mock function with required args - @_fixture_spark_expectations.with_expectations( - expectations, - write_to_table, - write_to_temp_table, - row_dq, - agg_dq={ - user_config.se_agg_dq: agg_dq, - user_config.se_source_agg_dq: source_agg_dq, - user_config.se_final_agg_dq: final_agg_dq, - }, - query_dq={ - user_config.se_query_dq: query_dq, - user_config.se_source_query_dq: source_query_dq, - user_config.se_final_query_dq: final_query_dq, - user_config.se_target_table_view: "" - }, - spark_conf={**spark_conf, **{user_config.se_notifications_on_fail: False}}, - options=options, - options_error_table=options_error_table, ) def get_dataset() -> DataFrame: return input_df - input_df.show(truncate=False) - with pytest.raises(SparkExpectationsMiscException, match=r"error occurred while processing spark expectations .*"): get_dataset() # decorated_func() + + for db in spark.catalog.listDatabases(): + if db.name != "default": + spark.sql(f"DROP DATABASE {db.name} CASCADE") + spark.sql("CLEAR CACHE") + + +def test_spark_expectations_exception(): + writer = WrappedDataFrameWriter().mode("append").format("parquet") + with pytest.raises(SparkExpectationsMiscException, match=r"Input rules_df is not of dataframe type"): + SparkExpectations(product_id="product1", + rules_df=[], + stats_table="dq_spark.test_dq_stats_table", + stats_table_writer=writer, + target_and_error_table_writer=writer, + debugger=False, + ) + + +# [UnitTests for WrappedDataFrameWriter class] + +def reset_wrapped_dataframe_writer(): + writer = WrappedDataFrameWriter() + writer._mode = None + writer._format = None + writer._partition_by = [] + writer._options = {} + writer._bucket_by = {} + writer._sort_by = [] + + +def test_mode(): + assert WrappedDataFrameWriter().mode("overwrite")._mode == "overwrite" + + +def test_format(): + assert WrappedDataFrameWriter().format("parquet")._format == "parquet" + + +def test_partitionBy(): + assert WrappedDataFrameWriter().partitionBy("date", "region")._partition_by == ["date", "region"] + + +def test_option(): + assert WrappedDataFrameWriter().option("compression", "gzip")._options == {"compression": "gzip"} + + +def test_options(): + assert WrappedDataFrameWriter().options(path="/path/to/output", inferSchema="true")._options == { + "path": "/path/to/output", "inferSchema": "true"} + + +def test_bucketBy(): + assert WrappedDataFrameWriter().bucketBy(4, "country", "city")._bucket_by == {"num_buckets": 4, + "columns": ("country", "city")} + + +def test_build(): + writer = WrappedDataFrameWriter().mode("overwrite") \ + .format("parquet") \ + .partitionBy("date", "region") \ + .option("compression", "gzip") \ + .options(path="/path/to/output", inferSchema="true") \ + .bucketBy(4, "country", "city") \ + .sortBy("col1", "col2") + expected_config = { + "mode": "overwrite", + "format": "parquet", + "partitionBy": ["date", "region"], + "options": {"compression": "gzip", "path": "/path/to/output", "inferSchema": 
"true"}, + "bucketBy": {"num_buckets": 4, "columns": ("country", "city")}, + "sortBy": ["col1", "col2"], + } + assert writer.build() == expected_config + + +def test_build_some_values(): + writer = WrappedDataFrameWriter().mode("append").format("iceberg") + + expected_config = { + "mode": "append", + "format": "iceberg", + "partitionBy": [], + "options": {}, + "bucketBy": {}, + "sortBy": [] + } + assert writer.build() == expected_config + + +def test_delta_bucketby_exception(): + writer = WrappedDataFrameWriter().mode("append").format("delta").bucketBy(10, "a", "b") + with pytest.raises(SparkExpectationsMiscException, match=r"Bucketing is not supported for delta tables yet"): + writer.build() diff --git a/tests/notification/push/test_spark_expectations_notify.py b/tests/notification/push/test_spark_expectations_notify.py index c682ad14..96fa5b6d 100644 --- a/tests/notification/push/test_spark_expectations_notify.py +++ b/tests/notification/push/test_spark_expectations_notify.py @@ -2,11 +2,11 @@ import pytest from spark_expectations.notifications.push.spark_expectations_notify import SparkExpectationsNotify from spark_expectations.core.exceptions import SparkExpectationsMiscException - +from spark_expectations.core.context import SparkExpectationsContext +from unittest.mock import Mock @pytest.fixture(name="_fixture_mock_context") -@patch('spark_expectations.notifications.push.spark_expectations_notify.SparkExpectationsContext', autospec=True, - spec_set=True) -def fixture_mock_context(_context_mock): +def fixture_mock_context(): + _context_mock = Mock(spec=SparkExpectationsContext) _context_mock.get_table_name = "test_table" _context_mock.get_run_id = "test_run_id" _context_mock.get_run_date = "test_run_date" @@ -23,6 +23,7 @@ def fixture_mock_context(_context_mock): _context_mock.get_final_query_dq_status = "skipped" _context_mock.get_dq_run_status = "fail" _context_mock.get_row_dq_rule_type_name = "row_dq" + _context_mock.product_id = "product_id1" return _context_mock @@ -96,7 +97,7 @@ def test_notify_on_start_completion_failure(_fixture_mock_context,): _fixture_mock_context.get_notification_on_completion = False _fixture_mock_context.get_notification_on_fail = False - notify_handler = SparkExpectationsNotify("product_id1", _fixture_mock_context) + notify_handler = SparkExpectationsNotify(_fixture_mock_context) @notify_handler.notify_on_start_completion_failure( lambda: print("start notification sent"), @@ -115,7 +116,7 @@ def dummy_function(raise_exception=False): spec_set=True) def test_notify_on_start(_mock_notification_hook, _fixture_mock_context, _fixture_notify_start_expected_result): - notify_handler = SparkExpectationsNotify("product_id1", _fixture_mock_context) + notify_handler = SparkExpectationsNotify(_fixture_mock_context) # Call the function to be tested notify_handler.notify_on_start() @@ -132,7 +133,7 @@ def test_notify_on_start(_mock_notification_hook, _fixture_mock_context, _fixtur def test_notify_on_completion(_mock_notification_hook, _fixture_mock_context, _fixture_notify_completion_expected_result): - notify_handler = SparkExpectationsNotify("product_id1", _fixture_mock_context) + notify_handler = SparkExpectationsNotify(_fixture_mock_context) # Call the function to be tested notify_handler.notify_on_completion() @@ -148,7 +149,7 @@ def test_notify_on_completion(_mock_notification_hook, _fixture_mock_context, def test_notify_on_exceeds_of_error_threshold(_mock_notification_hook, _fixture_mock_context, _fixture_notify_error_threshold_expected_result): - 
notify_handler = SparkExpectationsNotify("product_id1", _fixture_mock_context) + notify_handler = SparkExpectationsNotify(_fixture_mock_context) # Call the function to be tested notify_handler.notify_on_exceeds_of_error_threshold() @@ -163,7 +164,7 @@ def test_notify_on_exceeds_of_error_threshold(_mock_notification_hook, _fixture_ spec_set=True) def test_notify_on_failure(_mock_notification_hook, _fixture_mock_context, _fixture_notify_fail_expected_result): - notify_handler = SparkExpectationsNotify("product_id1", _fixture_mock_context) + notify_handler = SparkExpectationsNotify(_fixture_mock_context) # Call the function to be tested notify_handler.notify_on_failure("exception") @@ -181,7 +182,7 @@ def test_notify_on_start_success(_mock_notification_hook, _fixture_mock_context, _fixture_mock_context.get_notification_on_completion = False _fixture_mock_context.get_notification_on_fail = False - notify_handler = SparkExpectationsNotify("product1", _fixture_mock_context) + notify_handler = SparkExpectationsNotify(_fixture_mock_context) @notify_handler.notify_on_start_completion_failure( notify_handler.notify_on_start, @@ -206,7 +207,7 @@ def test_notify_on_completion_success(_mock_notification_hook, _fixture_mock_con _fixture_mock_context.get_notification_on_completion = True _fixture_mock_context.get_notification_on_fail = False - notify_handler = SparkExpectationsNotify("product_id1", _fixture_mock_context) + notify_handler = SparkExpectationsNotify(_fixture_mock_context) @notify_handler.notify_on_start_completion_failure( lambda: print("start notification sent"), @@ -232,7 +233,7 @@ def test_notify_on_failure_success(_mock_notification_hook, _fixture_mock_contex _fixture_mock_context.get_notification_on_completion = False _fixture_mock_context.get_notification_on_fail = True - notify_handler = SparkExpectationsNotify("product_id1", _fixture_mock_context) + notify_handler = SparkExpectationsNotify(_fixture_mock_context) @notify_handler.notify_on_start_completion_failure( lambda: print("start notification sent"), @@ -260,7 +261,7 @@ def test_notify_on_start_completion_failure_exception(_fixture_mock_context): _fixture_mock_context.get_notification_on_completion = False _fixture_mock_context.get_notification_on_fail = False - notify_handler = SparkExpectationsNotify("product1", _fixture_mock_context) + notify_handler = SparkExpectationsNotify(_fixture_mock_context) @notify_handler.notify_on_start_completion_failure( lambda: print("start notification sent"), @@ -279,7 +280,7 @@ def dummy_function(raise_exception=False): def test_construct_message_for_each_rules(_fixture_mock_context): # Create an instance of the class under test - notify_handler = SparkExpectationsNotify("product1", _fixture_mock_context) + notify_handler = SparkExpectationsNotify(_fixture_mock_context) # Set up the test input rule_name = "Rule 1" @@ -296,7 +297,7 @@ def test_construct_message_for_each_rules(_fixture_mock_context): # Assert the constructed notification message expected_message = ( "Rule 1 has been exceeded above the threshold value(2.5%) for `row_data` quality validation\n" - "product_id: product1\n" + "product_id: product_id1\n" f"table_name: {_fixture_mock_context.get_table_name}\n" f"run_id: {_fixture_mock_context.get_run_id}\n" f"run_date: {_fixture_mock_context.get_run_date}\n" @@ -315,7 +316,7 @@ def test_construct_message_for_each_rules(_fixture_mock_context): def test_notify_on_exceeds_of_error_threshold_each_rules(_notification_hook, _fixture_mock_context): from unittest import mock - notify_handler = 
SparkExpectationsNotify("product1", _fixture_mock_context) + notify_handler = SparkExpectationsNotify(_fixture_mock_context) # Define test data message = "Test message" @@ -356,7 +357,7 @@ def test_notify_rules_exceeds_threshold( {"rule": "rule1", "failed_row_count": failed_row_count} ] - notify_handler = SparkExpectationsNotify("product1", _fixture_mock_context) + notify_handler = SparkExpectationsNotify(_fixture_mock_context) rules = { @@ -394,7 +395,7 @@ def test_notify_rules_exceeds_threshold_return_none( _fixture_mock_context.get_summarised_row_dq_res = None - notify_handler = SparkExpectationsNotify("product1", _fixture_mock_context) + notify_handler = SparkExpectationsNotify(_fixture_mock_context) # Call the function to test assert notify_handler.notify_rules_exceeds_threshold({}) == None @@ -417,7 +418,7 @@ def test_notify_rules_exceeds_threshold_exception(_fixture_mock_context): ] } - notify_handler = SparkExpectationsNotify("product1", _fixture_mock_context) + notify_handler = SparkExpectationsNotify(_fixture_mock_context) # Expecting a SparkExpectationsMiscException to be raised with pytest.raises(SparkExpectationsMiscException, match="An error occurred while sending notification " diff --git a/tests/sinks/plugins/test_delta_writer.py b/tests/sinks/plugins/test_delta_writer.py deleted file mode 100644 index 020e3b49..00000000 --- a/tests/sinks/plugins/test_delta_writer.py +++ /dev/null @@ -1,90 +0,0 @@ -import os -import pytest -from pyspark.sql.types import StructType, StructField, IntegerType, StringType -from spark_expectations.core import get_spark_session -from spark_expectations.core.exceptions import SparkExpectationsMiscException -from spark_expectations.sinks.plugins.delta_writer import SparkExpectationsDeltaWritePluginImpl - -spark = get_spark_session() - - -@pytest.fixture(name="_fixture_create_database") -def fixture_create_database(): - # drop and create dq_spark if exists - os.system("rm -rf /tmp/hive/warehouse/dq_spark.db") - spark.sql("create database if not exists dq_spark") - spark.sql("use dq_spark") - - yield "dq_spark" - - # drop dq_spark if exists - os.system("rm -rf /tmp/hive/warehouse/dq_spark.db") - - -@pytest.fixture(name="_fixture_dataset") -def fixture_dataset(): - # Create a mock dataframe - data = [(1, "John", 25), (2, "Jane", 30), (3, "Jim", 35)] - schema = StructType([ - StructField("id", IntegerType(), True), - StructField("name", StringType(), True), - StructField("age", IntegerType(), True) - ]) - return spark.createDataFrame(data, schema) - - -@pytest.fixture(name="_fixture_create_test_table") -def fixture_create_test_table(): - # drop if exist dq_spark database and create with test_dq_stats_table - os.system("rm -rf /tmp/hive/warehouse/dq_spark.db") - spark.sql("create database if not exists dq_spark") - spark.sql("use dq_spark") - - spark.sql("drop table if exists test_table") - os.system("rm -rf /tmp/hive/warehouse/dq_spark.db/test_table") - spark.sql( - """ - create table test_table_write ( - id integer, - name string, - age integer - ) - USING delta - """ - ) - - yield "test_table" - - spark.sql("drop table if exists test_table_write") - os.system("rm -rf /tmp/hive/warehouse/dq_spark.db/test_table_write") - - # remove database - os.system("rm -rf /tmp/hive/warehouse/dq_spark.db") - - -# Write the test function -def test_writer(_fixture_create_database, _fixture_dataset, _fixture_create_test_table): - delta_writer_handler = SparkExpectationsDeltaWritePluginImpl() - - write_args = { - "stats_df": _fixture_dataset, - "table_name": 
"dq_spark.test_table_write" - } - - delta_writer_handler.writer(_write_args=write_args) - - expected_df = spark.table("test_table_write") - - assert expected_df.orderBy("id").collect() == _fixture_dataset.orderBy("id").collect() - - -def test_writer_exception(_fixture_create_database, _fixture_create_test_table): - delta_writer_handler = SparkExpectationsDeltaWritePluginImpl() - - write_args = { - "table_name": "dq_spark.test_table" - } - - with pytest.raises(SparkExpectationsMiscException, - match=r"error occurred while saving data into delta stats table .*"): - delta_writer_handler.writer(_write_args=write_args) diff --git a/tests/sinks/plugins/test_nsp_writer.py b/tests/sinks/plugins/test_kafka_writer.py similarity index 97% rename from tests/sinks/plugins/test_nsp_writer.py rename to tests/sinks/plugins/test_kafka_writer.py index 38de0077..54ba3a25 100644 --- a/tests/sinks/plugins/test_nsp_writer.py +++ b/tests/sinks/plugins/test_kafka_writer.py @@ -10,7 +10,7 @@ @pytest.fixture(name="_fixture_local_kafka_topic") -def fixture_setup_local_nsp_topic(): +def fixture_setup_local_kafka_topic(): current_dir = os.path.dirname(os.path.abspath(__file__)) if os.getenv('UNIT_TESTING_ENV') != "spark_expectations_unit_testing_on_github_actions": @@ -82,5 +82,5 @@ def test_kafka_writer_exception(_fixture_local_kafka_topic, _fixture_dataset): } } - with pytest.raises(SparkExpectationsMiscException, match=r"error occurred while saving data into NSP .*"): + with pytest.raises(SparkExpectationsMiscException, match=r"error occurred while saving data into kafka .*"): delta_writer_handler.writer(_write_args=write_args) diff --git a/tests/sinks/test__init__.py b/tests/sinks/test__init__.py index 01bd5559..8ebd055a 100644 --- a/tests/sinks/test__init__.py +++ b/tests/sinks/test__init__.py @@ -4,9 +4,6 @@ from pyspark.sql.types import StructType, StructField, IntegerType, StringType from spark_expectations.core import get_spark_session from spark_expectations.sinks import get_sink_hook, _sink_hook -from spark_expectations.sinks.plugins.delta_writer import ( - SparkExpectationsDeltaWritePluginImpl, -) from spark_expectations.sinks.plugins.kafka_writer import ( SparkExpectationsKafkaWritePluginImpl, ) @@ -29,7 +26,7 @@ def fixture_create_database(): @pytest.fixture(name="_fixture_local_kafka_topic") -def fixture_setup_local_nsp_topic(): +def fixture_setup_local_kafka_topic(): current_dir = os.path.dirname(os.path.abspath(__file__)) if os.getenv('UNIT_TESTING_ENV') != "spark_expectations_unit_testing_on_github_actions": @@ -67,17 +64,15 @@ def fixture_dataset(): def test_get_sink_hook(): pm = get_sink_hook() # Check that the correct number of plugins have been registered - assert len(pm.list_name_plugin()) == 2 + assert len(pm.list_name_plugin()) == 1 # Check that the correct plugins have been registered - assert isinstance(pm.get_plugin("spark_expectations_delta_write"), SparkExpectationsDeltaWritePluginImpl) assert isinstance(pm.get_plugin("spark_expectations_kafka_write"), SparkExpectationsKafkaWritePluginImpl) def test_sink_hook_write(_fixture_create_database, _fixture_local_kafka_topic, _fixture_dataset): write_args = { "stats_df": _fixture_dataset, - "table_name": "dq_spark.test_table", "kafka_write_options": { "kafka.bootstrap.servers": "localhost:9092", "topic": "dq-sparkexpectations-stats-local", @@ -89,7 +84,7 @@ def test_sink_hook_write(_fixture_create_database, _fixture_local_kafka_topic, _ _sink_hook.writer(_write_args=write_args) - expected_delta_df = spark.table("dq_spark.test_table") + # 
expected_delta_df = spark.table("dq_spark.test_table") expected_kafka_df = spark.read.format("kafka").option( "kafka.bootstrap.servers", "localhost:9092" @@ -100,6 +95,6 @@ def test_sink_hook_write(_fixture_create_database, _fixture_local_kafka_topic, _ ).load().orderBy(col('timestamp').desc()).limit(1).selectExpr( "cast(value as string) as stats_records") - assert expected_delta_df.collect() == _fixture_dataset.collect() + # assert expected_delta_df.collect() == _fixture_dataset.collect() assert expected_kafka_df.collect() == _fixture_dataset.selectExpr( "cast(to_json(struct(*)) as string) AS stats_records").collect() diff --git a/tests/sinks/utils/test_collect_statistics.py b/tests/sinks/utils/test_collect_statistics.py index 41019840..5106af8e 100644 --- a/tests/sinks/utils/test_collect_statistics.py +++ b/tests/sinks/utils/test_collect_statistics.py @@ -6,12 +6,15 @@ from spark_expectations.sinks.utils.collect_statistics import SparkExpectationsCollectStatistics from spark_expectations.sinks.utils.writer import SparkExpectationsWriter from spark_expectations.core.exceptions import SparkExpectationsMiscException +from unittest.mock import Mock +from spark_expectations.core.context import SparkExpectationsContext +from spark_expectations.core.expectations import WrappedDataFrameWriter spark = get_spark_session() @pytest.fixture(name="_fixture_local_kafka_topic") -def fixture_setup_local_nsp_topic(): +def fixture_setup_local_kafka_topic(): current_dir = os.path.dirname(os.path.abspath(__file__)) if os.getenv('UNIT_TESTING_ENV') != "spark_expectations_unit_testing_on_github_actions": @@ -195,10 +198,10 @@ def fixture_create_stats_table(): "error_percentage": 100.0, }) ]) -@patch('spark_expectations.sinks.utils.writer.SparkExpectationsContext', autospec=True, spec_set=True) -def test_collect_stats_on_success_failure(_mock_context, input_record, +def test_collect_stats_on_success_failure(input_record, expected_result, _fixture_local_kafka_topic, _fixture_create_stats_table): # create mock _context object + _mock_context = Mock(spec=SparkExpectationsContext) setattr(_mock_context, "get_dq_stats_table_name", "test_dq_stats_table") setattr(_mock_context, "get_run_date_name", "meta_dq_run_date") setattr(_mock_context, "get_run_date_time_name", "meta_dq_run_datetime") @@ -250,9 +253,17 @@ def test_collect_stats_on_success_failure(_mock_context, input_record, input_record.get("dq_rules").get("agg_dq_rules")) setattr(_mock_context, "get_num_query_dq_rules", input_record.get("dq_rules").get("query_dq_rules")) + setattr(_mock_context, "_stats_table_writer_config", + WrappedDataFrameWriter().mode("overwrite").format("delta").build()) + setattr(_mock_context, 'get_stats_table_writer_config', + WrappedDataFrameWriter().mode("overwrite").format("delta").build()) + setattr(_mock_context, 'get_se_streaming_stats_dict', {'se.enable.streaming': True}) - writer = SparkExpectationsWriter("product1", _mock_context) - statistics_writer_obj = SparkExpectationsCollectStatistics("product1", _mock_context, writer) + _mock_context.spark = spark + _mock_context.product_id = 'product1' + + writer = SparkExpectationsWriter(_mock_context) + statistics_writer_obj = SparkExpectationsCollectStatistics(_mock_context, writer) @statistics_writer_obj.collect_stats_on_success_failure() def exec_func(): @@ -406,10 +417,11 @@ def exec_func(): "error_percentage": 100.0, }) ]) -@patch('spark_expectations.sinks.utils.writer.SparkExpectationsContext', autospec=True, spec_set=True) -def 
test_collect_stats_on_success_failure_exception(_mock_context, input_record, + +def test_collect_stats_on_success_failure_exception(input_record, expected_result, _fixture_local_kafka_topic, _fixture_create_stats_table): + _mock_context = Mock(spec=SparkExpectationsContext) setattr(_mock_context, "get_dq_stats_table_name", "test_dq_stats_table") setattr(_mock_context, "get_run_date_name", "meta_dq_run_date") setattr(_mock_context, "get_run_date_time_name", "meta_dq_run_datetime") @@ -461,9 +473,16 @@ def test_collect_stats_on_success_failure_exception(_mock_context, input_record, input_record.get("dq_rules").get("agg_dq_rules")) setattr(_mock_context, "get_num_query_dq_rules", input_record.get("dq_rules").get("query_dq_rules")) + _mock_context.spark = spark + _mock_context.product_id = "product1" + setattr(_mock_context, "_stats_table_writer_config", + WrappedDataFrameWriter().mode("overwrite").format("delta").build()) + setattr(_mock_context, 'get_stats_table_writer_config', + WrappedDataFrameWriter().mode("overwrite").format("delta").build()) + setattr(_mock_context, 'get_se_streaming_stats_dict', {'se.enable.streaming': True}) - writer = SparkExpectationsWriter("product1", _mock_context) - statistics_writer_obj = SparkExpectationsCollectStatistics("product1", _mock_context, writer) + writer = SparkExpectationsWriter(_mock_context) + statistics_writer_obj = SparkExpectationsCollectStatistics(_mock_context, writer) @statistics_writer_obj.collect_stats_on_success_failure() def func_exception(): @@ -472,7 +491,6 @@ def func_exception(): with pytest.raises(SparkExpectationsMiscException): func_exception() - stats_table = spark.table("test_dq_stats_table") stats_table = spark.table("test_dq_stats_table") assert stats_table.count() == 1 row = stats_table.first() diff --git a/tests/sinks/utils/test_writer.py b/tests/sinks/utils/test_writer.py index 3762e943..f2452126 100644 --- a/tests/sinks/utils/test_writer.py +++ b/tests/sinks/utils/test_writer.py @@ -1,5 +1,8 @@ import os -from unittest.mock import patch +import unittest.mock +from unittest.mock import patch, Mock + +import pyspark.sql import pytest from pyspark.sql.functions import col from pyspark.sql.functions import lit, to_timestamp @@ -10,12 +13,13 @@ SparkExpectationsMiscException, SparkExpectationsUserInputOrConfigInvalidException ) +from spark_expectations.core.expectations import WrappedDataFrameWriter spark = get_spark_session() @pytest.fixture(name="_fixture_local_kafka_topic") -def fixture_setup_local_nsp_topic(): +def fixture_setup_local_kafka_topic(): current_dir = os.path.dirname(os.path.abspath(__file__)) if os.getenv('UNIT_TESTING_ENV') != "spark_expectations_unit_testing_on_github_actions": @@ -44,23 +48,26 @@ def fixture_employee_df(): @pytest.fixture(name="_fixture_writer") -@patch('spark_expectations.sinks.utils.writer.SparkExpectationsContext', autospec=True, spec_set=True) -def fixture_writer(mock_context): +def fixture_writer(): # create mock _context object + mock_context = Mock(spec=SparkExpectationsContext) setattr(mock_context, "get_dq_stats_table_name", "test_dq_stats_table") setattr(mock_context, "get_run_date", "2022-12-27 10:39:44") setattr(mock_context, "get_run_id", "product1_run_test") setattr(mock_context, "get_run_id_name", "meta_dq_run_id") setattr(mock_context, "get_run_date_name", "meta_dq_run_date") + mock_context.spark = spark + mock_context.product_id='product1' # Create an instance of the class and set the product_id - return SparkExpectationsWriter("product1", mock_context) + return 
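These writer tests now build a plain `Mock(spec=SparkExpectationsContext)` in place of the old `@patch` decorator and hand writer settings to it as the dict produced by `WrappedDataFrameWriter(...).build()`. A minimal sketch of that wiring, assuming the attribute names the tests set:

```python
from unittest.mock import Mock

from spark_expectations.core import get_spark_session
from spark_expectations.core.context import SparkExpectationsContext
from spark_expectations.core.expectations import WrappedDataFrameWriter
from spark_expectations.sinks.utils.writer import SparkExpectationsWriter

mock_context = Mock(spec=SparkExpectationsContext)
mock_context.spark = get_spark_session()
mock_context.product_id = "product1"
# build() yields a plain dict, e.g. {"mode": "overwrite", "format": "delta",
# "partitionBy": [], "options": {}, "bucketBy": {}, "sortBy": []}
mock_context.get_stats_table_writer_config = (
    WrappedDataFrameWriter().mode("overwrite").format("delta").build()
)

writer = SparkExpectationsWriter(mock_context)  # context-only constructor
```

+ return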
SparkExpectationsWriter(mock_context) @pytest.fixture(name="_fixture_create_employee_table") def fixture_create_employee_table(): # drop if exist dq_spark database and create with employee_table os.system("rm -rf /tmp/hive/warehouse/dq_spark.db") + spark.sql("drop database IF EXISTS dq_spark cascade") spark.sql("create database if not exists dq_spark") spark.sql("use dq_spark") spark.sql("create table employee_table USING delta") @@ -75,6 +82,7 @@ def fixture_create_employee_table(): def fixture_create_stats_table(): # drop if exist dq_spark database and create with test_dq_stats_table os.system("rm -rf /tmp/hive/warehouse/dq_spark.db") + spark.sql("DROP DATABASE IF EXISTS dq_spark CASCADE") spark.sql("create database if not exists dq_spark") spark.sql("use dq_spark") @@ -149,36 +157,43 @@ def fixture_expected_dq_dataset(): .withColumn("meta_dq_run_date", to_timestamp(lit("2022-12-27 10:39:44"))) -@pytest.mark.parametrize('table_name, spark_conf, options, expected_count', - [('employee_table', {"spark.sql.session.timeZone": "Etc/UTC"}, - {'mode': 'overwrite', 'partitionBy': ['department'], "format": "delta", - "overwriteSchema": "true"}, 1000), - ('employee_table', {"spark.sql.session.timeZone": "Etc/UTC"}, - {'mode': 'append', "format": "delta", "mergeSchema": "true"}, 1000) +@pytest.mark.parametrize('table_name, options, expected_count', + [('employee_table', + {'mode': 'overwrite', 'partitionBy': ['department'], "format": "parquet", + 'bucketBy': {'numBuckets':2,'colName':'business_unit'}, 'sortBy': ["eeid"], 'options': {"overwriteSchema": "true", "mergeSchema": "true"}}, 1000), + ('employee_table', + {'mode': 'append', "format": "delta", 'partitionBy': [], 'bucketBy': {}, 'sortBy': [], 'options': {"mergeSchema": "true"}}, + 1000) ]) def test_save_df_as_table(table_name, - spark_conf, options, expected_count, _fixture_employee, _fixture_writer, _fixture_create_employee_table): - _fixture_writer.save_df_as_table(_fixture_employee, table_name, spark_conf, options) + _fixture_writer.save_df_as_table(_fixture_employee, table_name, options, False) assert expected_count == spark.sql(f"select * from {table_name}").count() # Assert # _spark_set.assert_called_with('spark.sql.session.timeZone', 'Etc/UTC') +@patch('pyspark.sql.DataFrameWriter.save', autospec=True, spec_set=True) +def test_save_df_to_table_bq(save, _fixture_writer, _fixture_employee, _fixture_create_employee_table): + _fixture_writer.save_df_as_table(_fixture_employee, 'employee_table', {'mode': 'overwrite', 'format': 'bigquery', + 'partitionBy':[], 'bucketBy': {}, 'sortBy':[], 'options':{}}) + save.assert_called_once_with(unittest.mock.ANY) + + @pytest.mark.parametrize('table_name, options', [('employee_table', {'mode': 'overwrite', 'partitionBy': ['department'], "format": "delta", - "overwriteSchema": "true"}), + "overwriteSchema": "true", 'options': {}}), ('employee_table', {'mode': 'append', "format": "delta", - "mergeSchema": "true"}) + "mergeSchema": "true", 'options': {}}) ]) @patch('spark_expectations.sinks.utils.writer.SparkExpectationsWriter.save_df_as_table', autospec=True, spec_set=True) def test_write_df_to_table(save_df_as_table, @@ -188,13 +203,13 @@ def test_write_df_to_table(save_df_as_table, _fixture_writer, _fixture_create_employee_table): # Test the function with valid input - _fixture_writer.write_df_to_table(_fixture_employee, table_name, options=options) + _fixture_writer.save_df_as_table(_fixture_employee, table_name, config=options) save_df_as_table.assert_called_once_with( - _fixture_writer, 
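For reference, the single config dict consumed by the refactored `save_df_as_table` (replacing the old `spark_conf` plus `options` pair) has the shape spelled out by the parametrized cases above; the keys below are copied from those cases rather than from any documented schema:

```python
# Writer config shape per the parametrized test_save_df_as_table cases above.
config = {
    "mode": "overwrite",
    "format": "parquet",
    "partitionBy": ["department"],
    "bucketBy": {"numBuckets": 2, "colName": "business_unit"},
    "sortBy": ["eeid"],
    "options": {"overwriteSchema": "true", "mergeSchema": "true"},
}
# writer.save_df_as_table(df, "employee_table", config, False)  # trailing flag per the test call
```

- _fixture_writer,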
_fixture_employee, table_name, {"spark.sql.session.timeZone": "Etc/UTC"}, options + _fixture_writer, _fixture_employee, table_name, options ) -@pytest.mark.parametrize("input_record, expected_result", [ +@pytest.mark.parametrize("input_record, expected_result, writer_config", [ ({ "input_count": 100, "error_count": 10, @@ -237,7 +252,7 @@ def test_write_df_to_table(save_df_as_table, "output_percentage": 90.0, "success_percentage": 90.0, "error_percentage": 10.0 - }), + }, None), ({ "input_count": 100, "error_count": 10, @@ -273,7 +288,7 @@ def test_write_df_to_table(save_df_as_table, "output_percentage": 95.0, "success_percentage": 90.0, "error_percentage": 10.0, - }), + }, None), ({ "input_count": 100, "error_count": 100, @@ -311,7 +326,7 @@ def test_write_df_to_table(save_df_as_table, "output_percentage": 100.0, "success_percentage": 0.0, "error_percentage": 100.0, - }), + }, None), ({ "input_count": 100, "error_count": 100, @@ -347,7 +362,7 @@ def test_write_df_to_table(save_df_as_table, "output_percentage": 0.0, "success_percentage": 0.0, "error_percentage": 100.0, - }), + }, None), ({ "input_count": 100, "error_count": 100, @@ -387,7 +402,7 @@ def test_write_df_to_table(save_df_as_table, "output_percentage": 0.0, "success_percentage": 0.0, "error_percentage": 100.0 - }), + }, None), ({ "input_count": 100, "error_count": 100, @@ -424,15 +439,15 @@ def test_write_df_to_table(save_df_as_table, "output_percentage": 0.0, "success_percentage": 0.0, "error_percentage": 100.0, - }) + }, {'mode': 'append', "format": "bigquery", 'partitionBy': [], 'bucketBy': {}, 'sortBy': [], 'options': {"mergeSchema": "true"}}) ]) -@patch('spark_expectations.sinks.utils.writer.SparkExpectationsContext', autospec=True, spec_set=True) -def test_write_error_stats(_mock_context, - input_record, +def test_write_error_stats(input_record, expected_result, + writer_config, _fixture_create_stats_table, _fixture_local_kafka_topic): # create mock _context object + _mock_context = Mock(spec=SparkExpectationsContext) setattr(_mock_context, "get_dq_stats_table_name", "test_dq_stats_table") setattr(_mock_context, "get_run_date_name", "meta_dq_run_date") setattr(_mock_context, "get_run_date_time_name", "meta_dq_run_datetime") @@ -484,67 +499,81 @@ def test_write_error_stats(_mock_context, input_record.get("dq_rules").get("agg_dq_rules")) setattr(_mock_context, "get_num_query_dq_rules", input_record.get("dq_rules").get("query_dq_rules")) + setattr(_mock_context, 'get_dq_stats_table_name', 'test_dq_stats_table') - _fixture_writer = SparkExpectationsWriter("product1", _mock_context) + if writer_config is None: + setattr(_mock_context, "_stats_table_writer_config", WrappedDataFrameWriter().mode("overwrite").format("delta").build()) + setattr(_mock_context, 'get_stats_table_writer_config', WrappedDataFrameWriter().mode("overwrite").format("delta").build()) + else: + setattr(_mock_context, "_stats_table_writer_config", writer_config) + setattr(_mock_context, 'get_stats_table_writer_config', writer_config) - # Call the function being tested with some test input data - _fixture_writer.write_error_stats() + _mock_context.spark = spark + _mock_context.product_id = 'product1' - # Assert - stats_table = spark.table("test_dq_stats_table") - assert stats_table.count() == 1 - row = stats_table.first() - assert row.product_id == "product1" - assert row.table_name == "employee_table" - assert row.input_count == input_record.get("input_count") - assert row.error_count == input_record.get("error_count") - assert row.output_count == 
input_record.get("output_count") - assert row.output_percentage == expected_result.get("output_percentage") - assert row.success_percentage == expected_result.get("success_percentage") - assert row.error_percentage == expected_result.get("error_percentage") - assert row.source_agg_dq_results == input_record.get("source_agg_results") - assert row.final_agg_dq_results == input_record.get("final_agg_results") - assert row.source_query_dq_results == input_record.get("source_query_dq_results") - assert row.final_query_dq_results == input_record.get("final_query_dq_results") - assert row.dq_rules == input_record.get("dq_rules") - # assert row.dq_run_time == input_record.get("dq_run_time") - assert row.dq_status == input_record.get("status") - assert row.meta_dq_run_id == "product1_run_test" - - assert spark.read.format("kafka").option( - "kafka.bootstrap.servers", "localhost:9092" - ).option("subscribe", "dq-sparkexpectations-stats").option( - "startingOffsets", "earliest" - ).option( - "endingOffsets", "latest" - ).load().orderBy(col('timestamp').desc()).limit(1).selectExpr( - "cast(value as string) as value").collect() == stats_table.selectExpr( - "to_json(struct(*)) AS value").collect() - - # Assert spark conf.set - # _spark_set.assert_called_with('spark.sql.session.timeZone', 'Etc/UTC') + _fixture_writer = SparkExpectationsWriter(_mock_context) + if writer_config and writer_config['format'] == 'bigquery': + patcher = patch('pyspark.sql.DataFrameWriter.save', autospec=True, spec_set=True) + mock_bq = patcher.start() + setattr(_mock_context, 'get_se_streaming_stats_dict', {'se.enable.streaming': False}) + _fixture_writer.write_error_stats() + mock_bq.assert_called_with(unittest.mock.ANY) -@pytest.mark.parametrize('table_name, rule_type, spark_conf, options', + else: + setattr(_mock_context, 'get_se_streaming_stats_dict', {'se.enable.streaming': True}) + _fixture_writer.write_error_stats() + stats_table = spark.table("test_dq_stats_table") + assert stats_table.count() == 1 + row = stats_table.first() + assert row.product_id == "product1" + assert row.table_name == "employee_table" + assert row.input_count == input_record.get("input_count") + assert row.error_count == input_record.get("error_count") + assert row.output_count == input_record.get("output_count") + assert row.output_percentage == expected_result.get("output_percentage") + assert row.success_percentage == expected_result.get("success_percentage") + assert row.error_percentage == expected_result.get("error_percentage") + assert row.source_agg_dq_results == input_record.get("source_agg_results") + assert row.final_agg_dq_results == input_record.get("final_agg_results") + assert row.source_query_dq_results == input_record.get("source_query_dq_results") + assert row.final_query_dq_results == input_record.get("final_query_dq_results") + assert row.dq_rules == input_record.get("dq_rules") + # assert row.dq_run_time == input_record.get("dq_run_time") + assert row.dq_status == input_record.get("status") + assert row.meta_dq_run_id == "product1_run_test" + + assert spark.read.format("kafka").option( + "kafka.bootstrap.servers", "localhost:9092" + ).option("subscribe", "dq-sparkexpectations-stats").option( + "startingOffsets", "earliest" + ).option( + "endingOffsets", "latest" + ).load().orderBy(col('timestamp').desc()).limit(1).selectExpr( + "cast(value as string) as value").collect() == stats_table.selectExpr( + "to_json(struct(*)) AS value").collect() + + + + + +@pytest.mark.parametrize('table_name, rule_type', [('test_error_table', - 
'row_dq', - {"spark.sql.session.timeZone": "Etc/UTC"}, - {'mode': 'overwrite', "format": "delta"} + 'row_dq' ) ]) def test_write_error_records_final(table_name, rule_type, - spark_conf, - options, _fixture_dq_dataset, _fixture_expected_dq_dataset, _fixture_writer): + config = WrappedDataFrameWriter().mode("overwrite").format("delta").build() + + setattr(_fixture_writer._context, 'get_target_and_error_table_writer_config', config) # invoke the write_error_records_final method with the test fixtures as arguments result, _df = _fixture_writer.write_error_records_final(_fixture_dq_dataset, table_name, - rule_type, - spark_conf, - options) + rule_type) # error_df = spark.table("test_error_table") # assert that the returned value is the expected number of rows in the error table @@ -554,25 +583,20 @@ def test_write_error_records_final(table_name, assert _df.orderBy("id").collect() == _fixture_expected_dq_dataset.orderBy("id").collect() -@pytest.mark.parametrize('table_name, rule_type, spark_conf, options', - [('test_error_table', 'row_dq', {"spark.sql.session.timeZone": "Etc/UTC"}, - {'mode': 'overwrite', "format": "delta"}) +@pytest.mark.parametrize('table_name, rule_type', + [('test_error_table', 'row_dq') ]) @patch('spark_expectations.sinks.utils.writer.SparkExpectationsWriter.save_df_as_table', autospec=True, spec_set=True) def test_write_error_records_final_dependent(save_df_as_table, table_name, rule_type, - spark_conf, - options, _fixture_dq_dataset, _fixture_expected_error_dataset, _fixture_writer): # invoke the write_error_records_final method with the test fixtures as arguments result, _df = _fixture_writer.write_error_records_final(_fixture_dq_dataset, table_name, - rule_type, - spark_conf, - options) + rule_type) # assert that the returned value is the expected number of rows in the error table assert result == 3 @@ -584,9 +608,7 @@ def test_write_error_records_final_dependent(save_df_as_table, .withColumn("meta_dq_run_date", lit("2022-12-27 10:39:44")) \ .orderBy("id").collect() assert save_df_args[0][2] == table_name - assert save_df_args[0][3] == spark_conf - assert save_df_args[0][4] == options - save_df_as_table.assert_called_once_with(_fixture_writer, save_df_args[0][1], table_name, spark_conf, options) + save_df_as_table.assert_called_once_with(_fixture_writer, save_df_args[0][1], table_name, _fixture_writer._context.get_target_and_error_table_writer_config) @pytest.mark.parametrize("test_data, expected_result", [ @@ -612,8 +634,8 @@ def test_write_error_records_final_dependent(save_df_as_table, ) ]) def test_generate_summarised_row_dq_res(test_data, expected_result): - context = SparkExpectationsContext("product1") - writer = SparkExpectationsWriter("product1", context) + context = SparkExpectationsContext("product1", spark) + writer = SparkExpectationsWriter(context) # Create test DataFrame test_df = spark.createDataFrame(test_data) @@ -637,14 +659,22 @@ def test_generate_summarised_row_dq_res(test_data, expected_result): "description": "description1", "rule_type": "row_dq", "error_drop_threshold": "10", - } + }, + { + "rule": "rule2", + "enable_error_drop_alert": True, + "action_if_failed": "drop", + "description": "description1", + "rule_type": "row_dq", + "error_drop_threshold": "10", + }, ] }, [ - {"rule": "rule1", "failed_row_count": 10} + {"rule": "rule1", "failed_row_count": 10}, ], [ @@ -778,8 +808,8 @@ def test_generate_rules_exceeds_threshold(dq_rules, summarised_row_dq, expected_result, ): - _context = SparkExpectationsContext("product1") - _writer = 
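The remaining hunks in this file repeat one mechanical change: `SparkExpectationsContext` now takes the SparkSession explicitly, and `SparkExpectationsWriter` takes only the context, since `product_id` travels inside it. Sketched:

```python
from spark_expectations.core import get_spark_session
from spark_expectations.core.context import SparkExpectationsContext
from spark_expectations.sinks.utils.writer import SparkExpectationsWriter

spark = get_spark_session()
# The context carries both product_id and the session; the writer needs nothing else.
_context = SparkExpectationsContext("product1", spark)
_writer = SparkExpectationsWriter(_context)
```

- _writer =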
SparkExpectationsWriter("product1", _context) + _context = SparkExpectationsContext("product1", spark) + _writer = SparkExpectationsWriter(_context) _context.set_summarised_row_dq_res(summarised_row_dq) _context.set_input_count(100) @@ -787,6 +817,7 @@ def test_generate_rules_exceeds_threshold(dq_rules, _writer.generate_rules_exceeds_threshold(dq_rules) assert _context.get_rules_exceeds_threshold == expected_result + @pytest.mark.parametrize("test_data", [ ( [ @@ -810,19 +841,9 @@ def test_save_df_as_table_exception(_fixture_employee, with pytest.raises(SparkExpectationsUserInputOrConfigInvalidException, match=r"error occurred while writing data in to the table .*"): _fixture_writer.save_df_as_table(_fixture_employee, "employee_table", - {"spark.sql.session.timeZone": "Etc/UTC"}, {'mode': 'insert', "format": "test", "mergeSchema": "true"}) -def test_write_df_to_table_exception(_fixture_employee, - _fixture_writer): - with pytest.raises(SparkExpectationsMiscException, - match=r"error occurred while writing data in to the table .*"): - _fixture_writer.write_df_to_table(_fixture_employee, "employee_table", options={'mode': 'insert', - "format": "test", - "mergeSchema": "true"}) - - def test_write_error_stats_exception(_fixture_employee, _fixture_writer): with pytest.raises(SparkExpectationsMiscException, @@ -837,14 +858,12 @@ def test_write_error_records_final_exception(_fixture_employee, match=r"error occurred while saving data into the final error table .*"): _fixture_writer.write_error_records_final(_fixture_dq_dataset, "employee_table", - "row_dq", - options={'mode': 'insert', - "format": "test", - "mergeSchema": "true"}) + "row_dq") + def test_generate_rules_exceeds_threshold_exception(): - _context = SparkExpectationsContext("product1") - _writer = SparkExpectationsWriter("product1", _context) + _context = SparkExpectationsContext("product1", spark) + _writer = SparkExpectationsWriter(_context) _context.set_summarised_row_dq_res([{}]) _context.set_input_count(100) diff --git a/tests/utils/test_actions.py b/tests/utils/test_actions.py index 603e31cb..4d799f10 100644 --- a/tests/utils/test_actions.py +++ b/tests/utils/test_actions.py @@ -1,4 +1,3 @@ -import os from unittest.mock import Mock from unittest.mock import patch @@ -31,6 +30,7 @@ def fixture_mock_context(): # fixture for mock context mock_object = Mock(spec=SparkExpectationsContext) mock_object.product_id = "product1" + mock_object.spark=spark mock_object.get_row_dq_rule_type_name = "row_dq" mock_object.get_agg_dq_rule_type_name = "agg_dq" mock_object.get_query_dq_rule_type_name = "query_dq" @@ -282,10 +282,6 @@ def compare_result(_actual_output, _expected_output): {"action_if_failed": "ignore", "description": "desc2"}, {"action_if_failed": "ignore", "description": "desc3"}], ), # input_df - (None, - "agg_dq", # rule_type_name - # expected_output - None), # input_df (spark.createDataFrame( [{"meta_query_dq_results": [{"action_if_failed": "ignore", "description": "desc1"}, @@ -294,28 +290,38 @@ def compare_result(_actual_output, _expected_output): # expected_output "query_dq", # rule_type_name [{'action_if_failed': 'ignore', 'description': 'desc1'}] + ), + (spark.createDataFrame( + [{"meta_query_dq_results": [{"action_if_failed": "ignore", "description": "desc1"}, + ] + }]), + # expected_output + "row_dq", # rule_type_name + None ) ]) def test_create_agg_dq_results(input_df, rule_type_name, - expected_output, - _fixture_mock_context): + expected_output, _fixture_mock_context): # unit test case on create_agg_dq_results - assert 
diff --git a/tests/utils/test_reader.py b/tests/utils/test_reader.py
index dedcc780..bbca1d84 100644
--- a/tests/utils/test_reader.py
+++ b/tests/utils/test_reader.py
@@ -1,7 +1,7 @@
 import os
-import json
-from unittest.mock import patch
+from unittest.mock import patch, Mock
 import pytest
+
 # from pytest_mock import mocker // this will be automatically used while running using py-test
 from spark_expectations.core import get_spark_session
 from spark_expectations.utils.reader import SparkExpectationsReader
@@ -18,26 +18,16 @@
 @patch("spark_expectations.utils.reader.SparkExpectationsContext")
 def fixture_reader(_mocker_context):
     product_id = 'product1'
-    return SparkExpectationsReader(
-        product_id, _mocker_context
-    )
-
-
-@pytest.fixture(name="_fixture_reader_file")
-def fixture_reader_file():
-    context = SparkExpectationsContext("product_1")
-    context.set_config_file_path(os.path.dirname(__file__),
-                                 os.path.join(os.path.dirname(__file__),
-                                              "../resources/config/dq_spark_expectations_config.ini"))
-    return SparkExpectationsReader("product_1", context)
+    _mocker_context.spark = spark
+    return SparkExpectationsReader(_mocker_context)


 @pytest.fixture(name="_fixture_product_rules_view")
 def fixture_product_rules():
     df = (
         spark.read.option("header", "true")
-        .option("inferSchema", "true")
-        .csv(os.path.join(os.path.dirname(__file__), "../resources/product_rules.csv"))
+            .option("inferSchema", "true")
+            .csv(os.path.join(os.path.dirname(__file__), "../resources/product_rules.csv"))
     )

     # Set up the mock dataframe as a temporary table
@@ -67,7 +57,6 @@ def fixture_product_rules():
         "spark.expectations.notifications.email.subject": "",
     }, SparkExpectationsMiscException),
-
    ({
         "spark.expectations.notifications.email.enabled": False,
         "spark.expectations.notifications.email.smtp_host": "smtp.mail.com",
@@ -84,13 +73,14 @@
         "spark.expectations.notifications.slack.webhook_url": "",
     }, SparkExpectationsMiscException),
 ])
-@patch("spark_expectations.utils.reader.SparkExpectationsContext", autospec=True, spec_set=True)
-def test_set_notification_param(mock_context, notification, expected_result):
+def test_set_notification_param(notification, expected_result):
     # This function helps/implements test cases for while setting notification
     # configurations
+    mock_context = Mock(spec=SparkExpectationsContext)
+    mock_context.spark = spark

     # Create an instance of the class and set the product_id
-    reader_handler = SparkExpectationsReader("product1", mock_context)
+    reader_handler = SparkExpectationsReader(mock_context)

     if expected_result is None:
         assert reader_handler.set_notification_param(notification) == expected_result
@@ -121,29 +111,32 @@
 @pytest.mark.usefixtures("_fixture_product_rules_view")
-@pytest.mark.parametrize("product_id, table_name, action, tag, expected_output", [
-    ("product1", "table1", ["fail", "drop"], "tag2", {"rule2": "expectation2"}),
-    ("product2", "table1", ["drop", "ignore"], None, {"rule5": "expectation5", 'rule12': 'expectation12'}),
-    ("product1", "table1", ["fail", "drop", "ignore"], None,
+@pytest.mark.parametrize("product_id, table_name, tag, expected_output", [
+    ("product1", "table1", "tag2", {"rule2": "expectation2"}),
+    ("product2", "table1", None, {"rule5": "expectation5", "rule7": "expectation7", 'rule12': 'expectation12'}),
+    ("product1", "table1", None,
      {"rule1": "expectation1", "rule2": "expectation2", "rule3": "expectation3", "rule6": "expectation6",
      'rule10': 'expectation10', 'rule13': 'expectation13'}),
-    ("product2", "table2", ["fail", "drop", "ignore"], "tag7", {})
+    ("product2", "table2", "tag7", {})
 ])
-def test_get_rules_dlt(product_id, table_name, action, tag, expected_output, mocker, _fixture_product_rules_view):
+def test_get_rules_dlt(product_id, table_name, tag, expected_output, mocker, _fixture_product_rules_view):
     # create mock _context object
     mock_context = mocker.MagicMock()
+    mock_context.spark = spark
+    mock_context.product_id = product_id

     # Create an instance of the class and set the product_id
-    reade_handler = SparkExpectationsReader(product_id, mock_context)
-    rules_dlt = reade_handler.get_rules_dlt("product_rules", table_name, action, tag)
+    reade_handler = SparkExpectationsReader(mock_context)
+    rules_dlt, rules_settings = reade_handler.get_rules_from_df(spark.sql("select * from product_rules"), table_name,
+                                                                True, tag)

     # Assert
     assert rules_dlt == expected_output
 @pytest.mark.usefixtures("_fixture_product_rules_view")
-@pytest.mark.parametrize("product_id, table_name, action, expected_output", [
-    ("product1", "table1", ["fail", "drop"], {
+@pytest.mark.parametrize("product_id, table_name, expected_expectations, expected_rule_execution_settings", [
+    ("product1", "table1", {
         "target_table_name": "table1",
         "row_dq_rules": [
             {
@@ -175,6 +168,21 @@ def test_get_rules_dlt(product_id, table_name, action, tag, expected_output, moc
                 "description": "description2",
                 "enable_error_drop_alert": False,
                 "error_drop_threshold": 0
+            },
+            {
+                'action_if_failed': 'ignore',
+                'column_name': 'column3',
+                'description': 'description3',
+                'enable_error_drop_alert': False,
+                'enable_for_source_dq_validation': True,
+                'enable_for_target_dq_validation': True,
+                'error_drop_threshold': 0,
+                'expectation': 'expectation3',
+                'product_id': 'product1',
+                'rule': 'rule3',
+                'rule_type': 'row_dq',
+                'table_name': 'table1',
+                'tag': 'tag3'
             }
         ],
         "agg_dq_rules": [
@@ -192,6 +200,21 @@ def test_get_rules_dlt(product_id, table_name, action, tag, expected_output, moc
                 "description": "description6",
                 "enable_error_drop_alert": False,
                 "error_drop_threshold": 0
+            },
+            {
+                'action_if_failed': 'ignore',
+                'column_name': 'column7',
+                'description': 'description10',
+                'enable_error_drop_alert': False,
+                'enable_for_source_dq_validation': False,
+                'enable_for_target_dq_validation': True,
+                'error_drop_threshold': 0,
+                'expectation': 'expectation10',
+                'product_id': 'product1',
+                'rule': 'rule10',
+                'rule_type': 'agg_dq',
+                'table_name': 'table1',
+                'tag': 'tag10'
             }
         ],
         "query_dq_rules": [
@@ -211,257 +234,38 @@ def test_get_rules_dlt(product_id, table_name, action, tag, expected_output, moc
                 "error_drop_threshold": 0
             }
         ]
-    }),
-    ("product2", "table1", ["fail", "drop"], {
-        "target_table_name": "table1",
-        "row_dq_rules": [
-            {
-                "product_id": "product2",
-                "table_name": "table1",
-                "rule_type": "row_dq",
-                "rule": "rule5",
-                "column_name": "column5",
-                "expectation": "expectation5",
-                "action_if_failed": "drop",
-                "enable_for_source_dq_validation": True,
-                "enable_for_target_dq_validation": True,
-                "tag": "tag5",
-                "description": "description5",
-                "enable_error_drop_alert": True,
-                "error_drop_threshold": 20
-            }],
-        "agg_dq_rules": [
-            {
-                "product_id": "product2",
-                "table_name": "table1",
-                "rule_type": "agg_dq",
-                "rule": "rule7",
-                "column_name": "column4",
-                "expectation": "expectation7",
-                "action_if_failed": "fail",
-                "enable_for_source_dq_validation": True,
-                "enable_for_target_dq_validation": True,
-                "tag": "tag7",
-                "description": "description7",
-                "enable_error_drop_alert": False,
-                "error_drop_threshold": 0
-            }
-        ]
-    }),
-    ("product2", "table1", ["ignore"], {
-        "target_table_name": "table1",
-        "query_dq_rules": [
-            {
-                "product_id": "product2",
-                "table_name": "table1",
-                "rule_type": "query_dq",
-                "rule": "rule12",
-                "column_name": "column9",
-                "expectation": "expectation12",
-                "action_if_failed": "ignore",
-                "enable_for_source_dq_validation": True,
-                "enable_for_target_dq_validation": True,
-                "tag": "tag12",
-                "description": "description12",
-                "enable_error_drop_alert": False,
-                "error_drop_threshold": 0
-            }
-        ]
-    }),
-    ("product1", "table2", ["fail", "ignore"], {
-        "target_table_name": "table2",
-        "row_dq_rules": [
-            {
-                "product_id": "product1",
-                "table_name": "table2",
-                "rule_type": "row_dq",
-                "rule": "rule4",
"column_name": "column4", - "expectation": "expectation4", - "action_if_failed": "fail", - "enable_for_source_dq_validation": True, - "enable_for_target_dq_validation": True, - "tag": "tag4", - "description": "description4", - "enable_error_drop_alert": False, - "error_drop_threshold": 0 - } - ], - "agg_dq_rules": [ - { - "product_id": "product1", - "table_name": "table2", - "rule_type": "agg_dq", - "rule": "rule8", - "column_name": "column5", - "expectation": "expectation8", - "action_if_failed": "ignore", - "enable_for_source_dq_validation": True, - "enable_for_target_dq_validation": True, - "tag": "tag8", - "description": "description8", - "enable_error_drop_alert": False, - "error_drop_threshold": 0 - }, - { - "product_id": "product1", - "table_name": "table2", - "rule_type": "agg_dq", - "rule": "rule11", - "column_name": "column8", - "expectation": "expectation11", - "action_if_failed": "ignore", - "enable_for_source_dq_validation": True, - "enable_for_target_dq_validation": False, - "tag": "tag11", - "description": "description11", - "enable_error_drop_alert": False, - "error_drop_threshold": 0 - } - ], - "query_dq_rules": [ - { - "product_id": "product1", - "table_name": "table2", - "rule_type": "query_dq", - "rule": "rule14", - "column_name": "column11", - "expectation": "expectation14", - "action_if_failed": "fail", - "enable_for_source_dq_validation": False, - "enable_for_target_dq_validation": False, - "tag": "tag14", - "description": "description14", - "enable_error_drop_alert": False, - "error_drop_threshold": 0 - }, - { - "product_id": "product1", - "table_name": "table2", - "rule_type": "query_dq", - "rule": "rule15", - "column_name": "column12", - "expectation": "expectation15", - "action_if_failed": "ignore", - "enable_for_source_dq_validation": False, - "enable_for_target_dq_validation": True, - "tag": "tag15", - "description": "description15", - "enable_error_drop_alert": False, - "error_drop_threshold": 0 - } - ] - }), - ("product1", "table2", ["fail", "drop"], { - "target_table_name": "table2", - "row_dq_rules": [ - { - "product_id": "product1", - "table_name": "table2", - "rule_type": "row_dq", - "rule": "rule4", - "column_name": "column4", - "expectation": "expectation4", - "action_if_failed": "fail", - "enable_for_source_dq_validation": True, - "enable_for_target_dq_validation": True, - "tag": "tag4", - "description": "description4", - "enable_error_drop_alert": False, - "error_drop_threshold": 0 - } - ], - "query_dq_rules": [ - { - "product_id": "product1", - "table_name": "table2", - "rule_type": "query_dq", - "rule": "rule14", - "column_name": "column11", - "expectation": "expectation14", - "action_if_failed": "fail", - "enable_for_source_dq_validation": False, - "enable_for_target_dq_validation": False, - "tag": "tag14", - "description": "description14", - "enable_error_drop_alert": False, - "error_drop_threshold": 0 - } - ] - }), - ("product1", "table2", ["drop", "ignore"], { - "target_table_name": "table2", - "agg_dq_rules": [ - { - "product_id": "product1", - "table_name": "table2", - "rule_type": "agg_dq", - "rule": "rule8", - "column_name": "column5", - "expectation": "expectation8", - "enable_for_source_dq_validation": True, - "enable_for_target_dq_validation": True, - "action_if_failed": "ignore", - "tag": "tag8", - "description": "description8", - "enable_error_drop_alert": False, - "error_drop_threshold": 0 - }, - { - "product_id": "product1", - "table_name": "table2", - "rule_type": "agg_dq", - "rule": "rule11", - "column_name": "column8", - 
"expectation": "expectation11", - "action_if_failed": "ignore", - "enable_for_source_dq_validation": True, - "enable_for_target_dq_validation": False, - "tag": "tag11", - "description": "description11", - "enable_error_drop_alert": False, - "error_drop_threshold": 0 - } - ], - "query_dq_rules": [ - { - "product_id": "product1", - "table_name": "table2", - "rule_type": "query_dq", - "rule": "rule15", - "column_name": "column12", - "expectation": "expectation15", - "action_if_failed": "ignore", - "enable_for_source_dq_validation": False, - "enable_for_target_dq_validation": True, - "tag": "tag15", - "description": "description15", - "enable_error_drop_alert": False, - "error_drop_threshold": 0 - } - ] - - }) + }, { + # should be the output of the _get_rules_execution_settings from reader.py + "row_dq": True, + "source_agg_dq": True, + "target_agg_dq": True, + "source_query_dq": True, + "target_query_dq": False, + }) ]) -@patch("spark_expectations.utils.reader.SparkExpectationsContext", autospec=True, spec_set=True) -def test_get_rules_from_table(mock_context, product_id, table_name, - action, expected_output, _fixture_product_rules_view): +def test_get_rules_from_table(product_id, table_name, + expected_expectations, expected_rule_execution_settings, + _fixture_product_rules_view): # Create an instance of the class and set the product_id + mock_context = Mock(spec=SparkExpectationsContext) setattr(mock_context, "get_row_dq_rule_type_name", "row_dq") setattr(mock_context, "get_agg_dq_rule_type_name", "agg_dq") setattr(mock_context, "get_query_dq_rule_type_name", "query_dq") + mock_context.spark = spark + mock_context.product_id=product_id - reader_handler = SparkExpectationsReader(product_id, mock_context) + reader_handler = SparkExpectationsReader(mock_context) - result_dict = reader_handler.get_rules_from_table("product_rules", "test_dq_stats_table", - table_name, - action) + expectations, rule_execution_settings = reader_handler.get_rules_from_df(spark.sql(" select * from product_rules"), + table_name, + is_dlt=False + ) # Assert - assert result_dict == expected_output + assert expectations == expected_expectations + assert rule_execution_settings == expected_rule_execution_settings - mock_context.set_dq_stats_table_name.assert_called_once_with("test_dq_stats_table") mock_context.set_final_table_name.assert_called_once_with(table_name) mock_context.set_error_table_name.assert_called_once_with(f"{table_name}_error") @@ -473,14 +277,12 @@ def test_set_notification_param_exception(_fixture_reader): def test_get_rules_dlt_exception(_fixture_reader): - with pytest.raises(SparkExpectationsUserInputOrConfigInvalidException, - match=r"error occurred while reading or getting rules from the rules table .*"): - _fixture_reader.get_rules_dlt("product_rules_1", "table1", ["fail", "drop"]) + with pytest.raises(SparkExpectationsMiscException, + match=r"error occurred while retrieving rules list .*"): + _fixture_reader.get_rules_from_df("product_rules_1", "table1", is_dlt=True, tag=None) def test_get_rules_from_table_exception(_fixture_reader): with pytest.raises(SparkExpectationsMiscException, - match=r"error occurred while retrieving rules list from the table .*"): - _fixture_reader.get_rules_from_table("mock_rules_table_1", - "mock_dq_stats_table", - "table1", ["fail", "drop"]) + match=r"error occurred while retrieving rules list .*"): + _fixture_reader.get_rules_from_df("mock_rules_table_1", "table1", ) diff --git a/tests/utils/test_regulate_flow.py b/tests/utils/test_regulate_flow.py index 
diff --git a/tests/utils/test_regulate_flow.py b/tests/utils/test_regulate_flow.py
index dc6e76da..3b758671 100644
--- a/tests/utils/test_regulate_flow.py
+++ b/tests/utils/test_regulate_flow.py
@@ -5,6 +5,7 @@
 from pyspark.sql.functions import lit
 from spark_expectations.core import get_spark_session
 from spark_expectations.core.context import SparkExpectationsContext
+from spark_expectations.core.expectations import WrappedDataFrameWriter
 from spark_expectations.core.exceptions import (
     SparkExpectationsMiscException
 )
@@ -18,7 +19,7 @@
 @pytest.fixture(name="_fixture_context")
 def fixture_mock_context():
     # fixture for context
-    sparkexpectations_context = SparkExpectationsContext("product1")
+    sparkexpectations_context = SparkExpectationsContext("product1", spark)

     sparkexpectations_context._row_dq_rule_type_name = "row_dq"
     sparkexpectations_context._agg_dq_rule_type_name = "agg_dq"
@@ -27,6 +28,9 @@ def fixture_mock_context():
     sparkexpectations_context._run_date = "2022-12-27 10:39:44"
     sparkexpectations_context._run_id = "product1_run_test"
     sparkexpectations_context._input_count = 10
+    writer = WrappedDataFrameWriter().mode('overwrite').format('delta').build()
+    sparkexpectations_context.set_target_and_error_table_writer_config(writer)
+    sparkexpectations_context.set_stats_table_writer_config(writer)

     return sparkexpectations_context

@@ -1574,13 +1578,8 @@ def test_execute_dq_process(_mock_notify,
                             _fixture_context,
                             _fixture_create_stats_table):
     spark.conf.set("spark.sql.session.timeZone", "Etc/UTC")
-    spark_conf = {"spark.sql.session.timeZone": "Etc/UTC"}
-    options = {'mode': 'overwrite', "format": "delta"}
-    options_error_table = {'mode': 'overwrite', "format": "delta"}
-
     df.createOrReplaceTempView("test_table")
-
-    writer = SparkExpectationsWriter("product1", _fixture_context)
+    writer = SparkExpectationsWriter(_fixture_context)
     regulate_flow = SparkExpectationsRegulateFlow("product1")

     func_process = regulate_flow.execute_dq_process(
@@ -1590,9 +1589,7 @@ def test_execute_dq_process(_mock_notify,
         _mock_notify,
         expectations,
         "dq_spark.test_final_table",
-        input_count,
-        spark_conf,
-        options_error_table
+        input_count
     )

     # assert if expected output raises certain exception for failure
@@ -1800,13 +1797,9 @@ def test_execute_dq_process_exception(df,
     with pytest.raises(SparkExpectationsMiscException,
                        match=r"error occurred while executing func_process .*"):
         mock_contextt = Mock(spec=SparkExpectationsContext)
+        mock_contextt.spark = spark
         actions = SparkExpectationsActions()
-        writer = SparkExpectationsWriter("product1", mock_contextt)
-
-        spark_conf = {"spark.sql.session.timeZone": "Etc/UTC"}
-        options = {'mode': 'overwrite', "format": "delta"}
-        options_error_table = {'mode': 'overwrite', "format": "delta"}
-
+        writer = SparkExpectationsWriter(mock_contextt)
         regulate_flow = SparkExpectationsRegulateFlow("product1")
         func_process = regulate_flow.execute_dq_process(
             mock_contextt,
@@ -1814,9 +1807,7 @@ def test_execute_dq_process_exception(df,
             writer,
             expectations,
             "dq_spark.test_final_table",
-            input_count,
-            spark_conf,
-            options_error_table
+            input_count
         )

         (_df, _agg_dq_res, _error_count, _status) = func_process(df,