diff --git a/src/lambda_function/raw_sync/app.py b/src/lambda_function/raw_sync/app.py index 4e46101..aa01232 100644 --- a/src/lambda_function/raw_sync/app.py +++ b/src/lambda_function/raw_sync/app.py @@ -54,7 +54,7 @@ def lambda_handler(event: dict, context: dict) -> None: ) -def append_s3_key(key: str, key_format: str, result: dict) -> None: +def append_s3_key(key: str, key_format: str, result: defaultdict) -> defaultdict: """ Organizes an S3 object key by appending it to the appropriate entry in the result dictionary @@ -68,7 +68,7 @@ def append_s3_key(key: str, key_format: str, result: dict) -> None: it is structured as result[cohort]. Returns: - None + defaultdict: The `result` dict with `key` added. """ result = result.copy() # shallow copy safe for append if not key.endswith("/"): # Ignore keys that represent "folders" @@ -97,7 +97,7 @@ def append_s3_key(key: str, key_format: str, result: dict) -> None: def list_s3_objects( s3_client: boto3.client, bucket: str, key_prefix: str, key_format: str -) -> dict: +) -> defaultdict: """ Recursively list all objects under an S3 bucket and key prefix which conform to a specified format. @@ -159,6 +159,8 @@ def list_s3_objects( result = defaultdict(lambda: defaultdict(list)) elif key_format == "input": result = defaultdict(list) + else: + raise ValueError("Argument `key_format` must be either 'input' or 'raw'.") for response in response_iterator: for obj in response.get("Contents", []): key = obj["Key"] @@ -232,7 +234,7 @@ def parse_content_range(content_range: str) -> tuple[int, ...]: return range_start, range_end, total_size -def unpack_eocd_fields(body: bytes, eocd_offset: int) -> list[int]: +def unpack_eocd_fields(body: bytes, eocd_offset: int) -> tuple[int, int]: """ Extract the End of Central Directory (EOCD) fields from the given body. @@ -305,7 +307,7 @@ def determine_eocd_offset(body: bytes, content_range: str) -> int: def list_files_in_archive( s3_client: boto3.client, bucket: str, key: str, range_size=64 * 1024 -) -> list[str]: +) -> list[dict]: """ Recursively lists files in a ZIP archive stored as an S3 object. @@ -464,6 +466,23 @@ def publish_to_sns( sns_client.publish(TopicArn=sns_arn, Message=json.dumps(file_info)) +def get_data_type_from_path(path: str) -> str: + """ + Give the path of an export file, return its associated data type + + Args: + path (str): The path of an export file + + Returns: + data_type (str): The data type + """ + basename = os.path.basename(path) + data_type = basename.split("_")[0] + if "Deleted" in basename: + data_type = f"{data_type}_Deleted" + return data_type + + def main( event: dict, s3_client: boto3.client, @@ -499,7 +518,7 @@ def main( f"Checking corresponding raw object for {filename} " f"from s3://{input_bucket}/{export_key}" ) - data_type = filename.split("_")[0] + data_type = get_data_type_from_path(path=filename) file_identifier = filename.split(".")[0] expected_key = ( f"{namespace}/json/dataset={data_type}" diff --git a/tests/test_lambda_raw_sync.py b/tests/test_lambda_raw_sync.py index 227db50..58daa7d 100644 --- a/tests/test_lambda_raw_sync.py +++ b/tests/test_lambda_raw_sync.py @@ -612,3 +612,24 @@ def test_publish_to_sns_with_sqs_subscription(): sqs_client.delete_message( QueueUrl=sqs_url, ReceiptHandle=messages["Messages"][0]["ReceiptHandle"] ) + + +def test_get_data_type_from_path_simple(): + """Test a path with no subtype or 'Deleted' component""" + path = "path/to/FitbitIntradayCombined_20241111-20241112.json" + data_type = app.get_data_type_from_path(path=path) + assert data_type == "FitbitIntradayCombined" + + +def test_get_data_type_from_path_subtype(): + """Test a path with a subtype""" + path = "path/to/HealthKitV2Samples_AppleStandTime_20241111-20241112.json" + data_type = app.get_data_type_from_path(path=path) + assert data_type == "HealthKitV2Samples" + + +def test_get_data_type_from_path_deleted(): + """Test a path with a subtype and a 'Deleted' component""" + path = "path/to/HealthKitV2Samples_AppleStandTime_Deleted_20241111-20241112.json" + data_type = app.get_data_type_from_path(path=path) + assert data_type == "HealthKitV2Samples_Deleted"