
Commit

Update drop_table_duplicates function to use window + sort + rank + filter approach (#113)
philerooski authored May 8, 2024
1 parent f312f6a commit 5216f97
Showing 2 changed files with 75 additions and 64 deletions.
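For context, a minimal standalone sketch of the window + row_number + filter deduplication pattern this commit adopts. The data, column names, and the index_cols list are illustrative stand-ins, not taken from the repository; the job itself partitions by INDEX_FIELD_MAP[data_type].

from pyspark.sql import SparkSession, Window
from pyspark.sql.functions import col, row_number

spark = SparkSession.builder.getOrCreate()

# Two exports of the same record share the index column `record_id`.
df = spark.createDataFrame(
    [
        ("rec-1", "2024-05-01", "Chicago"),
        ("rec-1", "2024-05-02", "New York"),
        ("rec-2", "2024-05-01", "Boston"),
    ],
    ["record_id", "export_end_date", "city"],
)

index_cols = ["record_id"]  # illustrative stand-in for INDEX_FIELD_MAP[data_type]

# Rank rows within each index partition, newest export first, and keep only rank 1.
window_ordered = Window.partitionBy(index_cols).orderBy(col("export_end_date").desc())
deduplicated = (
    df.withColumn("ranking", row_number().over(window_ordered))
    .filter("ranking == 1")
    .drop("ranking")
)
deduplicated.show()  # rec-1 keeps only its 2024-05-02 row; rec-2 is untouched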
114 changes: 59 additions & 55 deletions src/glue/jobs/json_to_parquet.py
@@ -24,6 +24,9 @@
from awsglue.gluetypes import StructType
from awsglue.utils import getResolvedOptions
from pyspark import SparkContext
from pyspark.sql import Window
from pyspark.sql.functions import row_number, col
from pyspark.sql.dataframe import DataFrame

# Configure logger to use ECS formatting
logger = logging.getLogger(__name__)
@@ -122,7 +125,7 @@ def get_table(
glue_context: GlueContext,
record_counts: dict,
logger_context: dict,
) -> DynamicFrame:
"""
Return a table as a DynamicFrame with an unambiguous schema. Additionally,
we drop any superfluous partition_* fields which are added by Glue.
@@ -162,11 +165,10 @@ def get_table(

def drop_table_duplicates(
table: DynamicFrame,
table_name: str,
glue_context: GlueContext,
data_type: str,
record_counts: dict[str, list],
logger_context: dict,
) -> DynamicFrame:
) -> DataFrame:
"""
Drop duplicate samples and superfluous partition columns.
@@ -177,49 +179,52 @@
Args:
table (DynamicFrame): The table from which to drop duplicates
table_name (str): The name of the Glue table. Helps determine
data_type (str): The data type. Helps determine
the table index in conjunction with INDEX_FIELD_MAP.
glue_context (GlueContext): The glue context
record_counts (dict[str,list]): A dict mapping data types to a list
of counts, each of which corresponds to an `event` type.
logger_context (dict): A dictionary containing contextual information
to include with every log.
Returns:
awsglue.DynamicFrame: The `table` DynamicFrame after duplicates have been dropped.
pyspark.sql.dataframe.DataFrame: A Spark dataframe with duplicates removed.
"""
table_name_components = table_name.split("_")
table_data_type = table_name_components[1]
window_unordered = Window.partitionBy(INDEX_FIELD_MAP[data_type])
spark_df = table.toDF()
if "InsertedDate" in spark_df.columns:
sorted_spark_df = spark_df.sort(
[spark_df.InsertedDate.desc(), spark_df.export_end_date.desc()]
window_ordered = window_unordered.orderBy(
col("InsertedDate").desc(),
col("export_end_date").desc()
)
else:
sorted_spark_df = spark_df.sort(spark_df.export_end_date.desc())
table = DynamicFrame.fromDF(
dataframe=(
sorted_spark_df.dropDuplicates(subset=INDEX_FIELD_MAP[table_data_type])
),
glue_ctx=glue_context,
name=table_name,
window_ordered = window_unordered.orderBy(
col("export_end_date").desc()
)
table_no_duplicates = (
spark_df
.withColumn('ranking', row_number().over(window_ordered))
.filter("ranking == 1")
.drop("ranking")
.cache()
)
count_records_for_event(
table=table.toDF(),
table=table_no_duplicates,
event=CountEventType.DROP_DUPLICATES,
record_counts=record_counts,
logger_context=logger_context,
)
return table
return table_no_duplicates


def drop_deleted_healthkit_data(
glue_context: GlueContext,
table: DynamicFrame,
table: DataFrame,
table_name: str,
data_type: str,
glue_database: str,
record_counts: dict[str, list],
logger_context: dict,
) -> DynamicFrame:
) -> DataFrame:
"""
Drop records from a HealthKit table.
@@ -244,44 +249,37 @@
samples removed.
"""
glue_client = boto3.client("glue")
deleted_table_name = f"{table.name}_deleted"
table_data_type = table.name.split("_")[1]
deleted_table_name = f"{table_name}_deleted"
deleted_data_type = f"{data_type}_deleted"
try:
glue_client.get_table(DatabaseName=glue_database, Name=deleted_table_name)
except glue_client.exceptions.EntityNotFoundException:
return table
deleted_table_logger_context = deepcopy(logger_context)
deleted_table_logger_context["labels"]["glue_table_name"] = deleted_table_name
deleted_table_logger_context["labels"]["type"] = f"{table_data_type}_deleted"
deleted_table_logger_context["labels"]["type"] = deleted_data_type
deleted_table_raw = get_table(
table_name=deleted_table_name,
database_name=glue_database,
glue_context=glue_context,
record_counts=record_counts,
logger_context=deleted_table_logger_context,
)
# we use `data_type` rather than `deleted_data_type` here because they share
# an index (we don't bother including `deleted_data_type` in `INDEX_FIELD_MAP`).
deleted_table = drop_table_duplicates(
table=deleted_table_raw,
table_name=deleted_table_name,
glue_context=glue_context,
data_type=data_type,
record_counts=record_counts,
logger_context=deleted_table_logger_context,
)
table_df = table.toDF()
deleted_table_df = deleted_table.toDF()
table_with_deleted_samples_removed = DynamicFrame.fromDF(
dataframe=(
table_df.join(
other=deleted_table_df,
on=INDEX_FIELD_MAP[table_data_type],
table_with_deleted_samples_removed = table.join(
other=deleted_table,
on=INDEX_FIELD_MAP[data_type],
how="left_anti",
)
),
glue_ctx=glue_context,
name=table.name,
)
count_records_for_event(
table=table_with_deleted_samples_removed.toDF(),
table=table_with_deleted_samples_removed,
event=CountEventType.DROP_DELETED_SAMPLES,
record_counts=record_counts,
logger_context=logger_context,
@@ -295,7 +293,7 @@ def archive_existing_datasets(
workflow_name: str,
workflow_run_id: str,
delete_upon_completion: bool,
) -> list[dict]:
"""
Archives existing datasets in S3 by copying them to a timestamped subfolder
within an "archive" folder. The format of the timestamped subfolder is:
@@ -363,7 +361,7 @@ def write_table_to_s3(
workflow_name: str,
workflow_run_id: str,
records_per_partition: int = int(1e6),
) -> None:
"""
Write a DynamicFrame to S3 as a parquet dataset.
@@ -432,11 +430,11 @@ class CountEventType(Enum):


def count_records_for_event(
table: "pyspark.sql.dataframe.DataFrame",
table: DataFrame,
event: CountEventType,
record_counts: dict[str, list],
logger_context: dict,
) -> dict[str, list]:
"""
Compute record count statistics for each `export_end_date`.
@@ -483,7 +481,7 @@ def store_record_counts(
namespace: str,
workflow_name: str,
workflow_run_id: str,
) -> dict[str, str]:
"""
Uploads record counts as S3 objects.
@@ -529,7 +527,7 @@ def add_index_to_table(
table_name: str,
processed_tables: dict[str, DynamicFrame],
unprocessed_tables: dict[str, DynamicFrame],
) -> "pyspark.sql.dataframe.DataFrame":
) -> DataFrame:
"""Add partition and index fields to a DynamicFrame.
A DynamicFrame containing the top-level fields already includes the index
@@ -618,10 +616,12 @@ def main() -> None:
# Get args and setup environment
args, workflow_run_properties = get_args()
glue_context = GlueContext(SparkContext.getOrCreate())
table_name = args["glue_table"]
data_type = args["glue_table"].split("_")[1]
logger_context = {
"labels": {
"glue_table_name": args["glue_table"],
"type": args["glue_table"].split("_")[1],
"glue_table_name": table_name,
"type": data_type,
"job_name": args["JOB_NAME"],
},
"process.parent.pid": args["WORKFLOW_RUN_ID"],
@@ -634,7 +634,6 @@ def main() -> None:
job.init(args["JOB_NAME"], args)

# Read table and drop duplicated and deleted samples
table_name = args["glue_table"]
table_raw = get_table(
table_name=table_name,
database_name=workflow_run_properties["glue_database"],
@@ -646,24 +645,29 @@
return
table = drop_table_duplicates(
table=table_raw,
table_name=table_name,
glue_context=glue_context,
data_type=data_type,
record_counts=record_counts,
logger_context=logger_context,
)
if "healthkit" in table_name:
table = drop_deleted_healthkit_data(
glue_context=glue_context,
table=table,
table_name=table_name,
data_type=data_type,
glue_database=workflow_run_properties["glue_database"],
record_counts=record_counts,
logger_context=logger_context,
)

table_dynamic = DynamicFrame.fromDF(
dataframe=table,
glue_ctx=glue_context,
name=table_name
)
# Export new table records to parquet
if has_nested_fields(table.schema()):
if has_nested_fields(table.schema):
tables_with_index = {}
table_relationalized = table.relationalize(
table_relationalized = table_dynamic.relationalize(
root_table_name=table_name,
staging_path=f"s3://{workflow_run_properties['parquet_bucket']}/tmp/",
transformation_ctx="relationalize",
@@ -702,7 +706,7 @@ def main() -> None:
)
else:
write_table_to_s3(
dynamic_frame=table,
dynamic_frame=table_dynamic,
bucket=workflow_run_properties["parquet_bucket"],
key=os.path.join(
workflow_run_properties["namespace"],
@@ -714,7 +718,7 @@
glue_context=glue_context,
)
count_records_for_event(
table=table.toDF(),
table=table_dynamic.toDF(),
event=CountEventType.WRITE,
record_counts=record_counts,
logger_context=logger_context,
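Before the test changes: the related edit in drop_deleted_healthkit_data keeps the same left anti join as before but now operates directly on Spark DataFrames instead of converting to and from DynamicFrames. A hedged, standalone sketch of that join; the table contents and the index_cols list are illustrative, while the job itself joins on INDEX_FIELD_MAP[data_type].

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Illustrative stand-ins for a HealthKit table and its "_deleted" companion table.
table = spark.createDataFrame(
    [("s1", "steps"), ("s2", "steps"), ("s3", "heartrate")],
    ["SampleId", "Type"],
)
deleted = spark.createDataFrame([("s2",)], ["SampleId"])

index_cols = ["SampleId"]  # illustrative stand-in for INDEX_FIELD_MAP[data_type]

# A left anti join keeps only rows of `table` whose index has no match in `deleted`.
remaining = table.join(other=deleted, on=index_cols, how="left_anti")
remaining.show()  # s2 is dropped; s1 and s3 remain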
25 changes: 16 additions & 9 deletions tests/test_json_to_parquet.py
@@ -121,6 +121,10 @@ def create_table(
)
return table

@pytest.fixture()
def flat_data_type(glue_flat_table_name):
flat_data_type = glue_flat_table_name.split("_")[1]
return flat_data_type

@pytest.fixture()
def sample_table(
@@ -654,29 +658,29 @@ def test_get_table_nested(
["partition_" in field.name for field in nested_table.schema().fields]
)

def test_drop_table_duplicates(self, sample_table, glue_context, logger_context):
def test_drop_table_duplicates(
self, sample_table, flat_data_type, glue_context, logger_context
):
table_no_duplicates = json_to_parquet.drop_table_duplicates(
table=sample_table,
table_name=sample_table.name,
glue_context=glue_context,
data_type=flat_data_type,
record_counts=defaultdict(list),
logger_context=logger_context,
)
table_no_duplicates_df = table_no_duplicates.toDF().toPandas()
table_no_duplicates_df = table_no_duplicates.toPandas()
assert len(table_no_duplicates_df) == 3
assert "Chicago" in set(table_no_duplicates_df.city)

def test_drop_table_duplicates_inserted_date(
self, sample_table_inserted_date, glue_context, logger_context
self, sample_table_inserted_date, flat_data_type, glue_context, logger_context
):
table_no_duplicates = json_to_parquet.drop_table_duplicates(
table=sample_table_inserted_date,
table_name=sample_table_inserted_date.name,
glue_context=glue_context,
data_type=flat_data_type,
record_counts=defaultdict(list),
logger_context=logger_context,
)
table_no_duplicates_df = table_no_duplicates.toDF().toPandas()
table_no_duplicates_df = table_no_duplicates.toPandas()
assert len(table_no_duplicates_df) == 3
assert set(["John", "Jane", "Bob_2"]) == set(
table_no_duplicates_df["name"].tolist()
@@ -923,6 +927,7 @@ def test_drop_deleted_healthkit_data(
self,
glue_context,
glue_flat_table_name,
flat_data_type,
glue_database_name,
logger_context,
):
@@ -935,7 +940,9 @@
)
table_after_drop = json_to_parquet.drop_deleted_healthkit_data(
glue_context=glue_context,
table=table,
table=table.toDF(),
table_name=table.name,
data_type=flat_data_type,
glue_database=glue_database_name,
record_counts=defaultdict(list),
logger_context=logger_context,


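Because drop_table_duplicates and drop_deleted_healthkit_data now return Spark DataFrames, main() wraps the result back into a DynamicFrame before relationalizing and writing (the table_dynamic = DynamicFrame.fromDF(...) block above). A rough sketch of that round trip, assuming a Glue environment that already provides glue_context; the helper name and the placeholder dropDuplicates() call are illustrative only, not the job's actual dedup logic.

from awsglue.dynamicframe import DynamicFrame

def dedupe_and_rewrap(dynamic_frame, glue_context, table_name):
    spark_df = dynamic_frame.toDF()           # Glue DynamicFrame -> Spark DataFrame
    deduplicated = spark_df.dropDuplicates()  # placeholder for the window-based dedup sketched earlier
    # Wrap back into a DynamicFrame so Glue-only APIs (relationalize, the Glue write sinks) still apply.
    return DynamicFrame.fromDF(
        dataframe=deduplicated,
        glue_ctx=glue_context,
        name=table_name,
    )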