From a6c8cd833df5b4bc594e5f49df0dba82829b4e2c Mon Sep 17 00:00:00 2001
From: Phil Snyder
Date: Tue, 7 May 2024 16:49:25 -0700
Subject: [PATCH] Update `drop_table_duplicates` function to use window + sort + rank + filter approach

---
 src/glue/jobs/json_to_parquet.py | 114 ++++++++++++++++---------------
 tests/test_json_to_parquet.py    |  25 ++++---
 2 files changed, 75 insertions(+), 64 deletions(-)

diff --git a/src/glue/jobs/json_to_parquet.py b/src/glue/jobs/json_to_parquet.py
index d61c86fd..7ebda420 100644
--- a/src/glue/jobs/json_to_parquet.py
+++ b/src/glue/jobs/json_to_parquet.py
@@ -24,6 +24,9 @@ from awsglue.gluetypes import StructType
 from awsglue.utils import getResolvedOptions
 from pyspark import SparkContext
+from pyspark.sql import Window
+from pyspark.sql.functions import row_number, col
+from pyspark.sql.dataframe import DataFrame
 
 # Configure logger to use ECS formatting
 logger = logging.getLogger(__name__)
@@ -122,7 +125,7 @@ def get_table(
     glue_context: GlueContext,
     record_counts: dict,
     logger_context: dict,
-) -> DynamicFrame:
+ ) -> DynamicFrame:
     """
     Return a table as a DynamicFrame with an unambiguous schema. Additionally,
     we drop any superfluous partition_* fields which are added by Glue.
@@ -162,11 +165,10 @@ def get_table(
 
 def drop_table_duplicates(
     table: DynamicFrame,
-    table_name: str,
-    glue_context: GlueContext,
+    data_type: str,
     record_counts: dict[str, list],
     logger_context: dict,
-) -> DynamicFrame:
+ ) -> DataFrame:
     """
     Drop duplicate samples and superfluous partition columns.
 
@@ -177,49 +179,52 @@ def drop_table_duplicates(
 
     Args:
         table (DynamicFrame): The table from which to drop duplicates
-        table_name (str): The name of the Glue table. Helps determine
+        data_type (str): The data type. Helps determine
             the table index in conjunction with INDEX_FIELD_MAP.
-        glue_context (GlueContext): The glue context
         record_counts (dict[str,list]): A dict mapping data types to a list
             of counts, each of which corresponds to an `event` type.
         logger_context (dict): A dictionary containing contextual information
             to include with every log.
 
     Returns:
-        awsglue.DynamicFrame: The `table` DynamicFrame after duplicates have been dropped.
+        pyspark.sql.dataframe.DataFrame: A Spark dataframe with duplicates removed.
     """
-    table_name_components = table_name.split("_")
-    table_data_type = table_name_components[1]
+    window_unordered = Window.partitionBy(INDEX_FIELD_MAP[data_type])
     spark_df = table.toDF()
     if "InsertedDate" in spark_df.columns:
-        sorted_spark_df = spark_df.sort(
-            [spark_df.InsertedDate.desc(), spark_df.export_end_date.desc()]
+        window_ordered = window_unordered.orderBy(
+            col("InsertedDate").desc(),
+            col("export_end_date").desc()
         )
     else:
-        sorted_spark_df = spark_df.sort(spark_df.export_end_date.desc())
-    table = DynamicFrame.fromDF(
-        dataframe=(
-            sorted_spark_df.dropDuplicates(subset=INDEX_FIELD_MAP[table_data_type])
-        ),
-        glue_ctx=glue_context,
-        name=table_name,
+        window_ordered = window_unordered.orderBy(
+            col("export_end_date").desc()
+        )
+    table_no_duplicates = (
+        spark_df
+        .withColumn('ranking', row_number().over(window_ordered))
+        .filter("ranking == 1")
+        .drop("ranking")
+        .cache()
     )
     count_records_for_event(
-        table=table.toDF(),
+        table=table_no_duplicates,
        event=CountEventType.DROP_DUPLICATES,
         record_counts=record_counts,
         logger_context=logger_context,
     )
-    return table
+    return table_no_duplicates
 
 
 def drop_deleted_healthkit_data(
     glue_context: GlueContext,
-    table: DynamicFrame,
+    table: DataFrame,
+    table_name: str,
+    data_type: str,
     glue_database: str,
     record_counts: dict[str, list],
     logger_context: dict,
-) -> DynamicFrame:
+ ) -> DataFrame:
     """
     Drop records from a HealthKit table.
 
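The de-duplication pattern this patch adopts (partition by the table index, order by recency, keep only the first-ranked row) can be exercised with plain PySpark outside of Glue. The sketch below uses made-up column names and data; it is not code from this repository:

```python
# Minimal sketch of the window + row_number + filter de-duplication pattern.
# Column names and values here are illustrative only.
from pyspark.sql import SparkSession, Window
from pyspark.sql.functions import col, row_number

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame(
    [
        ("a", "2024-01-01", 1),
        ("a", "2024-02-01", 2),  # newer export; this row should survive
        ("b", "2024-01-01", 3),
    ],
    ["record_id", "export_end_date", "value"],
)
# Partition by the table index, order newest export first, keep rank 1 only.
window = Window.partitionBy("record_id").orderBy(col("export_end_date").desc())
deduplicated = (
    df.withColumn("ranking", row_number().over(window))
    .filter("ranking == 1")
    .drop("ranking")
)
deduplicated.show()  # one row per record_id, the one with the latest export_end_date
```

Unlike a bare `dropDuplicates`, the explicit ordering inside the window makes it deterministic which duplicate survives.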
""" - table_name_components = table_name.split("_") - table_data_type = table_name_components[1] + window_unordered = Window.partitionBy(INDEX_FIELD_MAP[data_type]) spark_df = table.toDF() if "InsertedDate" in spark_df.columns: - sorted_spark_df = spark_df.sort( - [spark_df.InsertedDate.desc(), spark_df.export_end_date.desc()] + window_ordered = window_unordered.orderBy( + col("InsertedDate").desc(), + col("export_end_date").desc() ) else: - sorted_spark_df = spark_df.sort(spark_df.export_end_date.desc()) - table = DynamicFrame.fromDF( - dataframe=( - sorted_spark_df.dropDuplicates(subset=INDEX_FIELD_MAP[table_data_type]) - ), - glue_ctx=glue_context, - name=table_name, + window_ordered = window_unordered.orderBy( + col("export_end_date").desc() + ) + table_no_duplicates = ( + spark_df + .withColumn('ranking', row_number().over(window_ordered)) + .filter("ranking == 1") + .drop("ranking") + .cache() ) count_records_for_event( - table=table.toDF(), + table=table_no_duplicates, event=CountEventType.DROP_DUPLICATES, record_counts=record_counts, logger_context=logger_context, ) - return table + return table_no_duplicates def drop_deleted_healthkit_data( glue_context: GlueContext, - table: DynamicFrame, + table: DataFrame, + table_name: str, + data_type: str, glue_database: str, record_counts: dict[str, list], logger_context: dict, -) -> DynamicFrame: + ) -> DataFrame: """ Drop records from a HealthKit table. @@ -244,15 +249,15 @@ def drop_deleted_healthkit_data( samples removed. """ glue_client = boto3.client("glue") - deleted_table_name = f"{table.name}_deleted" - table_data_type = table.name.split("_")[1] + deleted_table_name = f"{table_name}_deleted" + deleted_data_type = f"{data_type}_deleted" try: glue_client.get_table(DatabaseName=glue_database, Name=deleted_table_name) except glue_client.exceptions.EntityNotFoundException: return table deleted_table_logger_context = deepcopy(logger_context) deleted_table_logger_context["labels"]["glue_table_name"] = deleted_table_name - deleted_table_logger_context["labels"]["type"] = f"{table_data_type}_deleted" + deleted_table_logger_context["labels"]["type"] = deleted_data_type deleted_table_raw = get_table( table_name=deleted_table_name, database_name=glue_database, @@ -260,28 +265,21 @@ def drop_deleted_healthkit_data( record_counts=record_counts, logger_context=deleted_table_logger_context, ) + # we use `data_type` rather than `deleted_data_type` here because they share + # an index (we don't bother including `deleted_data_type` in `INDEX_FIELD_MAP`). 
@@ -295,7 +293,7 @@ def archive_existing_datasets(
     workflow_name: str,
     workflow_run_id: str,
     delete_upon_completion: bool,
-) -> list[dict]:
+ ) -> list[dict]:
     """
     Archives existing datasets in S3 by copying them to a timestamped subfolder
     within an "archive" folder. The format of the timestamped subfolder is:
@@ -363,7 +361,7 @@ def write_table_to_s3(
     workflow_name: str,
     workflow_run_id: str,
     records_per_partition: int = int(1e6),
-) -> None:
+ ) -> None:
     """
     Write a DynamicFrame to S3 as a parquet dataset.
 
@@ -432,11 +430,11 @@ class CountEventType(Enum):
 
 
 def count_records_for_event(
-    table: "pyspark.sql.dataframe.DataFrame",
+    table: DataFrame,
     event: CountEventType,
     record_counts: dict[str, list],
     logger_context: dict,
-) -> dict[str, list]:
+ ) -> dict[str, list]:
     """
     Compute record count statistics for each `export_end_date`.
 
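`count_records_for_event` only has its type annotations updated here, and its body is not part of this patch. The sketch below merely illustrates the kind of per-`export_end_date` counting its docstring describes, using an assumed groupBy approach rather than the actual implementation:

```python
# Hypothetical illustration of counting records per export_end_date.
from pyspark.sql import SparkSession
from pyspark.sql.functions import count

spark = SparkSession.builder.master("local[1]").getOrCreate()
table = spark.createDataFrame(
    [("2024-01-01", "x"), ("2024-01-01", "y"), ("2024-02-01", "z")],
    ["export_end_date", "record_id"],
)
counts = table.groupBy("export_end_date").agg(count("*").alias("record_count"))
counts.show()
```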
@@ -483,7 +481,7 @@ def store_record_counts(
     namespace: str,
     workflow_name: str,
     workflow_run_id: str,
-) -> dict[str, str]:
+ ) -> dict[str, str]:
     """
     Uploads record counts as S3 objects.
 
@@ -529,7 +527,7 @@ def add_index_to_table(
     table_name: str,
     processed_tables: dict[str, DynamicFrame],
     unprocessed_tables: dict[str, DynamicFrame],
-) -> "pyspark.sql.dataframe.DataFrame":
+ ) -> DataFrame:
     """Add partition and index fields to a DynamicFrame.
 
     A DynamicFrame containing the top-level fields already includes the index
@@ -618,10 +616,12 @@ def main() -> None:
     # Get args and setup environment
     args, workflow_run_properties = get_args()
     glue_context = GlueContext(SparkContext.getOrCreate())
+    table_name = args["glue_table"]
+    data_type = args["glue_table"].split("_")[1]
     logger_context = {
         "labels": {
-            "glue_table_name": args["glue_table"],
-            "type": args["glue_table"].split("_")[1],
+            "glue_table_name": table_name,
+            "type": data_type,
             "job_name": args["JOB_NAME"],
         },
         "process.parent.pid": args["WORKFLOW_RUN_ID"],
@@ -634,7 +634,6 @@ def main() -> None:
     job.init(args["JOB_NAME"], args)
 
     # Read table and drop duplicated and deleted samples
-    table_name = args["glue_table"]
     table_raw = get_table(
         table_name=table_name,
         database_name=workflow_run_properties["glue_database"],
@@ -646,8 +645,7 @@ def main() -> None:
         return
     table = drop_table_duplicates(
         table=table_raw,
-        table_name=table_name,
-        glue_context=glue_context,
+        data_type=data_type,
         record_counts=record_counts,
         logger_context=logger_context,
     )
@@ -655,15 +653,21 @@ def main() -> None:
     table = drop_deleted_healthkit_data(
         glue_context=glue_context,
         table=table,
+        table_name=table_name,
+        data_type=data_type,
         glue_database=workflow_run_properties["glue_database"],
         record_counts=record_counts,
         logger_context=logger_context,
     )
-
+    table_dynamic = DynamicFrame.fromDF(
+        dataframe=table,
+        glue_ctx=glue_context,
+        name=table_name
+    )
     # Export new table records to parquet
-    if has_nested_fields(table.schema()):
+    if has_nested_fields(table.schema):
         tables_with_index = {}
-        table_relationalized = table.relationalize(
+        table_relationalized = table_dynamic.relationalize(
             root_table_name=table_name,
             staging_path=f"s3://{workflow_run_properties['parquet_bucket']}/tmp/",
             transformation_ctx="relationalize",
@@ -702,7 +706,7 @@ def main() -> None:
         )
     else:
         write_table_to_s3(
-            dynamic_frame=table,
+            dynamic_frame=table_dynamic,
             bucket=workflow_run_properties["parquet_bucket"],
             key=os.path.join(
                 workflow_run_properties["namespace"],
@@ -714,7 +718,7 @@ def main() -> None:
             glue_context=glue_context,
         )
     count_records_for_event(
-        table=table.toDF(),
+        table=table_dynamic.toDF(),
         event=CountEventType.WRITE,
         record_counts=record_counts,
         logger_context=logger_context,
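Because `drop_table_duplicates` and `drop_deleted_healthkit_data` now return Spark DataFrames, `main()` wraps the result back into a DynamicFrame before relationalizing or writing, and `has_nested_fields` now receives `table.schema` (a DataFrame property) rather than `table.schema()` (a DynamicFrame method). A minimal sketch of that conversion; it assumes an AWS Glue job environment, and `to_dynamic_frame` is a hypothetical helper name:

```python
# Sketch only: requires the AWS Glue libraries and a live GlueContext.
from awsglue.context import GlueContext
from awsglue.dynamicframe import DynamicFrame
from pyspark.sql.dataframe import DataFrame


def to_dynamic_frame(spark_df: DataFrame, glue_context: GlueContext, name: str) -> DynamicFrame:
    """Wrap a Spark DataFrame back into a Glue DynamicFrame (hypothetical helper)."""
    # Note: DataFrame.schema is a property, while DynamicFrame.schema() is a method.
    return DynamicFrame.fromDF(dataframe=spark_df, glue_ctx=glue_context, name=name)
```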
diff --git a/tests/test_json_to_parquet.py b/tests/test_json_to_parquet.py
index eba5feea..60512e28 100644
--- a/tests/test_json_to_parquet.py
+++ b/tests/test_json_to_parquet.py
@@ -121,6 +121,10 @@ def create_table(
     )
     return table
 
+@pytest.fixture()
+def flat_data_type(glue_flat_table_name):
+    flat_data_type = glue_flat_table_name.split("_")[1]
+    return flat_data_type
 
 @pytest.fixture()
 def sample_table(
@@ -654,29 +658,29 @@ def test_get_table_nested(
             ["partition_" in field.name for field in nested_table.schema().fields]
         )
 
-    def test_drop_table_duplicates(self, sample_table, glue_context, logger_context):
+    def test_drop_table_duplicates(
+        self, sample_table, flat_data_type, glue_context, logger_context
+    ):
         table_no_duplicates = json_to_parquet.drop_table_duplicates(
             table=sample_table,
-            table_name=sample_table.name,
-            glue_context=glue_context,
+            data_type=flat_data_type,
             record_counts=defaultdict(list),
             logger_context=logger_context,
         )
-        table_no_duplicates_df = table_no_duplicates.toDF().toPandas()
+        table_no_duplicates_df = table_no_duplicates.toPandas()
         assert len(table_no_duplicates_df) == 3
         assert "Chicago" in set(table_no_duplicates_df.city)
 
     def test_drop_table_duplicates_inserted_date(
-        self, sample_table_inserted_date, glue_context, logger_context
+        self, sample_table_inserted_date, flat_data_type, glue_context, logger_context
     ):
         table_no_duplicates = json_to_parquet.drop_table_duplicates(
             table=sample_table_inserted_date,
-            table_name=sample_table_inserted_date.name,
-            glue_context=glue_context,
+            data_type=flat_data_type,
             record_counts=defaultdict(list),
             logger_context=logger_context,
         )
-        table_no_duplicates_df = table_no_duplicates.toDF().toPandas()
+        table_no_duplicates_df = table_no_duplicates.toPandas()
         assert len(table_no_duplicates_df) == 3
         assert set(["John", "Jane", "Bob_2"]) == set(
             table_no_duplicates_df["name"].tolist()
         )
@@ -923,6 +927,7 @@ def test_drop_deleted_healthkit_data(
         self,
         glue_context,
         glue_flat_table_name,
+        flat_data_type,
         glue_database_name,
         logger_context,
     ):
@@ -935,7 +940,9 @@ def test_drop_deleted_healthkit_data(
         )
         table_after_drop = json_to_parquet.drop_deleted_healthkit_data(
             glue_context=glue_context,
-            table=table,
+            table=table.toDF(),
+            table_name=table.name,
+            data_type=flat_data_type,
             glue_database=glue_database_name,
             record_counts=defaultdict(list),
             logger_context=logger_context,
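The new `flat_data_type` fixture mirrors how the job now derives the data type from the Glue table name (the second underscore-delimited token) and how `drop_deleted_healthkit_data` builds the deleted-table names. The table name below is hypothetical, for illustration only:

```python
# Hypothetical Glue table name, for illustration only.
table_name = "dataset_fitbitactivitylogs"
data_type = table_name.split("_")[1]           # "fitbitactivitylogs"
deleted_table_name = f"{table_name}_deleted"   # "dataset_fitbitactivitylogs_deleted"
deleted_data_type = f"{data_type}_deleted"     # "fitbitactivitylogs_deleted"
assert data_type == "fitbitactivitylogs"
```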