diff --git a/e2e/test_batch_processing.py b/e2e/test_batch_processing.py deleted file mode 100644 index 0e093b44e..000000000 --- a/e2e/test_batch_processing.py +++ /dev/null @@ -1,108 +0,0 @@ -import copy -import logging -import unittest -import uuid -from time import sleep - -# import boto3 -# from botocore.config import Config - -# from utils.batch import BatchFile, base_headers, base_record, get_s3_source_name, get_cluster_name, \ -# download_report_file, get_s3_destination_name, CtlFile, CtlData - -from utils.batch import BatchFile, base_headers, base_record, \ - CtlFile, CtlData - -logger = logging.getLogger("batch") -logger.setLevel(logging.DEBUG) -formatter = logging.Formatter('%(message)s') -handler = logging.StreamHandler() -handler.setFormatter(formatter) -logger.addHandler(handler) - - -def _wait_for_batch_processing(ecs_client, cluster: str): - """wait for batch processing to finish in the given cluster by polling the ECS tasks status""" - timeout_seconds = 120 - elapsed = 0 - logger.debug(f"waiting for tasks to finish in cluster: {cluster}") - while elapsed <= timeout_seconds: - response = ecs_client.list_tasks( - cluster=cluster, - desiredStatus='RUNNING' - ) - num_tasks = len(response['taskArns']) - if num_tasks == 0: - break - sleep(1) - elapsed += 1 - if elapsed % 10 == 0: - logger.debug(f"waiting for batch processing to finish, elapsed: {elapsed} seconds") - - if elapsed >= timeout_seconds: - raise TimeoutError("Batch processing tasks did not finish in time") - - -def make_ctl_file() -> CtlFile: - ctl = CtlData(from_dts="a_date", to_dts="another_date") - return CtlFile(ctl) - - -class TestBatchProcessing(unittest.TestCase): - s3_client = None - dat_key = None - ctl_key = None - report_key = None - - # @classmethod - # def setUpClass(cls): - # name = str(uuid.uuid4()) - # cls.dat_key = f"{name}.dat" - # cls.ctl_key = f"{name}.ctl" - # cls.report_key = cls.dat_key - - # session = boto3.Session(profile_name="apim-dev") - # aws_conf = Config(region_name="eu-west-2") - # s3_client = session.client("s3", config=aws_conf) - # cls.s3_client = s3_client - - # source_bucket = get_s3_source_name() - # date_file = cls.make_batch_file() - # # date_file.upload_to_s3(cls.s3_client, source_bucket, cls.dat_key) - # date_file.upload_to_s3(cls.s3_client, source_bucket) - # ctl_file = make_ctl_file() - # # ctl_file.upload_to_s3(cls.s3_client, source_bucket, cls.ctl_key) - # ctl_file.upload_to_s3(cls.s3_client, source_bucket) - - # logger.debug("waiting for the batch processing to finish") - # # wait for the event rule to start the task and then wait for the task to finish - # sleep(5) - # ecs_client = session.client('ecs', config=aws_conf) - # _wait_for_batch_processing(ecs_client, get_cluster_name()) - - @staticmethod - def make_batch_file() -> BatchFile: - # Try to create records for each scenario because, it takes time for uploading - # and waiting for the processing to finish. 
- bf = BatchFile(headers=base_headers) - logger.debug("creating batch file with headers:") - logger.debug(base_headers) - - record1 = copy.deepcopy(base_record) - record1["UNIQUE_ID"] = str(uuid.uuid4()) - bf.add_record(record1, "happy-path") - - record2 = copy.deepcopy(base_record) - record2["UNIQUE_ID"] = str(uuid.uuid4()) - record2["PERSON_DOB"] = "bad-date" - bf.add_record(record1, "error") - - logger.debug(bf.stream.getvalue().decode("utf-8")) - return bf - - # def test_batch_file(self): - # report = download_report_file(self.s3_client, get_s3_destination_name(), self.report_key) - # logger.debug(f"report:\n{report}") - - # lines = report.splitlines() - # self.assertEqual(len(lines), 1) diff --git a/filenameprocessor/src/audit_table.py b/filenameprocessor/src/audit_table.py index 74294260d..3870faeb8 100644 --- a/filenameprocessor/src/audit_table.py +++ b/filenameprocessor/src/audit_table.py @@ -28,7 +28,7 @@ def add_to_audit_table(message_id: str, file_key: str, created_at_formatted_str: Item={ "message_id": {"S": message_id}, "filename": {"S": file_key}, - "status": {"S": "Processed"}, + "status": {"S": "Not processed - duplicate" if duplicate_exists else "Processed"}, "timestamp": {"S": created_at_formatted_str}, }, ConditionExpression="attribute_not_exists(message_id)", # Prevents accidental overwrites diff --git a/filenameprocessor/src/file_name_processor.py b/filenameprocessor/src/file_name_processor.py index abef4eccb..0e058a059 100644 --- a/filenameprocessor/src/file_name_processor.py +++ b/filenameprocessor/src/file_name_processor.py @@ -77,7 +77,7 @@ def handle_record(record) -> dict: UnhandledSqsError, Exception, ) as error: - logger.error("Error processing file'%s': %s", file_key, str(error)) + logger.error("Error processing file '%s': %s", file_key, str(error)) # Create ack file # (note that error may have occurred before message_id and created_at_formatted_string were generated) diff --git a/filenameprocessor/src/logging_decorator.py b/filenameprocessor/src/logging_decorator.py index 56e883efd..d4e2745ed 100644 --- a/filenameprocessor/src/logging_decorator.py +++ b/filenameprocessor/src/logging_decorator.py @@ -32,11 +32,13 @@ def generate_and_send_logs( def logging_decorator(func): - """Sends the appropriate logs to Cloudwatch and Firehose based on the function result. + """ + Sends the appropriate logs to Cloudwatch and Firehose based on the function result. NOTE: The function must return a dictionary as its only return value. The dictionary is expected to contain all of the required additional details for logging. NOTE: Logs will include the result of the function call or, in the case of an Exception being raised, - a status code of 500 and the error message.""" + a status code of 500 and the error message. 
+ """ @wraps(func) def wrapper(*args, **kwargs): diff --git a/recordprocessor/src/batch_processing.py b/recordprocessor/src/batch_processing.py index 81b5bd7c8..7c25ccff2 100644 --- a/recordprocessor/src/batch_processing.py +++ b/recordprocessor/src/batch_processing.py @@ -3,15 +3,11 @@ import json import os import time -from constants import Constants -from utils_for_recordprocessor import get_csv_content_dict_reader -from unique_permission import get_unique_action_flags_from_s3 -from make_and_upload_ack_file import make_and_upload_ack_file -from get_operation_permissions import get_operation_permissions from process_row import process_row -from mappings import Vaccine from send_to_kinesis import send_to_kinesis from clients import logger +from file_level_validation import file_level_validation +from errors import NoOperationPermissions, InvalidHeaders def process_csv_to_fhir(incoming_message_body: dict) -> None: @@ -20,91 +16,45 @@ def process_csv_to_fhir(incoming_message_body: dict) -> None: and documents the outcome for each row in the ack file. """ logger.info("Event: %s", incoming_message_body) - # Get details needed to process file - file_id = incoming_message_body.get("message_id") - vaccine: Vaccine = next( # Convert vaccine_type to Vaccine enum - vaccine for vaccine in Vaccine if vaccine.value == incoming_message_body.get("vaccine_type").upper() - ) - supplier = incoming_message_body.get("supplier").upper() - file_key = incoming_message_body.get("filename") - permission = incoming_message_body.get("permission") - created_at_formatted_string = incoming_message_body.get("created_at_formatted_string") - allowed_operations = get_operation_permissions(vaccine, permission) - # Fetch the data - bucket_name = os.getenv("SOURCE_BUCKET_NAME") - csv_reader, csv_data = get_csv_content_dict_reader(bucket_name, file_key) - is_valid_headers = validate_content_headers(csv_reader) - - # Validate has permission to perform at least one of the requested actions - action_flag_check = validate_action_flag_permissions(supplier, vaccine.value, permission, csv_data) - - if not action_flag_check or not is_valid_headers: - make_and_upload_ack_file(file_id, file_key, False, False, created_at_formatted_string) - else: - # Initialise the accumulated_ack_file_content with the headers - make_and_upload_ack_file(file_id, file_key, True, True, created_at_formatted_string) - - row_count = 0 # Initialize a counter for rows - for row in csv_reader: - row_count += 1 - row_id = f"{file_id}^{row_count}" - logger.info("MESSAGE ID : %s", row_id) - # Process the row to obtain the details needed for the message_body and ack file - details_from_processing = process_row(vaccine, allowed_operations, row) - - # Create the message body for sending - outgoing_message_body = { - "row_id": row_id, - "file_key": file_key, - "supplier": supplier, - "vax_type": vaccine.value, - "created_at_formatted_string": created_at_formatted_string, - **details_from_processing, - } - - send_to_kinesis(supplier, outgoing_message_body) + try: + interim_message_body = file_level_validation(incoming_message_body=incoming_message_body) + except (InvalidHeaders, NoOperationPermissions, Exception): # pylint: disable=broad-exception-caught + # If the file is invalid, processing should cease immediately + return None + + file_id = interim_message_body.get("message_id") + vaccine = interim_message_body.get("vaccine") + supplier = interim_message_body.get("supplier") + file_key = interim_message_body.get("file_key") + allowed_operations = 
interim_message_body.get("allowed_operations") + created_at_formatted_string = interim_message_body.get("created_at_formatted_string") + csv_reader = interim_message_body.get("csv_dict_reader") + + row_count = 0 # Initialize a counter for rows + for row in csv_reader: + row_count += 1 + row_id = f"{file_id}^{row_count}" + logger.info("MESSAGE ID : %s", row_id) + + # Process the row to obtain the details needed for the message_body and ack file + details_from_processing = process_row(vaccine, allowed_operations, row) + + # Create the message body for sending + outgoing_message_body = { + "row_id": row_id, + "file_key": file_key, + "supplier": supplier, + "vax_type": vaccine.value, + "created_at_formatted_string": created_at_formatted_string, + **details_from_processing, + } + + send_to_kinesis(supplier, outgoing_message_body) logger.info("Total rows processed: %s", row_count) -def validate_content_headers(csv_content_reader): - """Returns a bool to indicate whether the given CSV headers match the 34 expected headers exactly""" - return csv_content_reader.fieldnames == Constants.expected_csv_headers - - -def validate_action_flag_permissions( - supplier: str, vaccine_type: str, allowed_permissions_list: list, csv_data -) -> bool: - """ - Returns True if the supplier has permission to perform ANY of the requested actions for the given vaccine type, - else False. - """ - # If the supplier has full permissions for the vaccine type, return True - if f"{vaccine_type}_FULL" in allowed_permissions_list: - return True - - # Get unique ACTION_FLAG values from the S3 file - operations_requested = get_unique_action_flags_from_s3(csv_data) - - # Convert action flags into the expected operation names - requested_permissions_set = { - f"{vaccine_type}_{'CREATE' if action == 'NEW' else action}" for action in operations_requested - } - - # Check if any of the CSV permissions match the allowed permissions - if requested_permissions_set.intersection(allowed_permissions_list): - logger.info( - "%s permissions %s match one of the requested permissions required to %s", - supplier, - allowed_permissions_list, - requested_permissions_set, - ) - return True - - return False - - def main(event: str) -> None: """Process each row of the file""" logger.info("task started") @@ -114,7 +64,7 @@ def main(event: str) -> None: except Exception as error: # pylint: disable=broad-exception-caught logger.error("Error processing message: %s", error) end = time.time() - logger.info(f"Total time for completion:{round(end - start, 5)}s") + logger.info("Total time for completion: %ss", round(end - start, 5)) if __name__ == "__main__": diff --git a/recordprocessor/src/clients.py b/recordprocessor/src/clients.py index f64627db2..762d5807c 100644 --- a/recordprocessor/src/clients.py +++ b/recordprocessor/src/clients.py @@ -7,6 +7,11 @@ s3_client = boto3_client("s3", region_name=REGION_NAME) kinesis_client = boto3_client("kinesis", region_name=REGION_NAME) +sqs_client = boto3_client("sqs", region_name=REGION_NAME) +firehose_client = boto3_client("firehose", region_name=REGION_NAME) + +# Logger logging.basicConfig(level="INFO") logger = logging.getLogger() +logger.setLevel("INFO") diff --git a/recordprocessor/src/errors.py b/recordprocessor/src/errors.py new file mode 100644 index 000000000..bde724576 --- /dev/null +++ b/recordprocessor/src/errors.py @@ -0,0 +1,9 @@ +"""Custom exceptions for the Record Processor.""" + + +class NoOperationPermissions(Exception): + """A custom exception for when the supplier has no permissions for any of the 
requested operations.""" + + +class InvalidHeaders(Exception): + """A custom exception for when the file headers are invalid.""" diff --git a/recordprocessor/src/file_level_validation.py b/recordprocessor/src/file_level_validation.py new file mode 100644 index 000000000..61019cc30 --- /dev/null +++ b/recordprocessor/src/file_level_validation.py @@ -0,0 +1,100 @@ +""" +Functions for completing file-level validation +(validating headers and ensuring that the supplier has permission to perform at least one of the requested operations) +""" + +from constants import Constants +from unique_permission import get_unique_action_flags_from_s3 +from clients import logger +from make_and_upload_ack_file import make_and_upload_ack_file +from mappings import Vaccine +from utils_for_recordprocessor import get_csv_content_dict_reader +from errors import InvalidHeaders, NoOperationPermissions +from logging_decorator import file_level_validation_logging_decorator + + +def validate_content_headers(csv_content_reader) -> None: + """Raises an InvalidHeaders error if the headers in the CSV file do not match the expected headers.""" + if csv_content_reader.fieldnames != Constants.expected_csv_headers: + raise InvalidHeaders("File headers are invalid.") + + +def validate_action_flag_permissions( + supplier: str, vaccine_type: str, allowed_permissions_list: list, csv_data: str +) -> set: + """ + Validates that the supplier has permission to perform at least one of the requested operations for the given + vaccine type and returns the set of allowed operations for that vaccine type. + Raises a NoOperationPermissions error if the supplier does not have permission to perform any of the requested operations. + """ + # If the supplier has full permissions for the vaccine type, return the full set of operations + if f"{vaccine_type}_FULL" in allowed_permissions_list: + return {"CREATE", "UPDATE", "DELETE"} + + # Get unique ACTION_FLAG values from the S3 file + operations_requested = get_unique_action_flags_from_s3(csv_data) + + # Convert action flags into the expected operation names + requested_permissions_set = { + f"{vaccine_type}_{'CREATE' if action == 'NEW' else action}" for action in operations_requested + } + + # Check if any of the CSV permissions match the allowed permissions + if not requested_permissions_set.intersection(allowed_permissions_list): + raise NoOperationPermissions(f"{supplier} does not have permissions to perform any of the requested actions.") + + logger.info( + "%s permissions %s match one of the requested permissions required to %s", + supplier, + allowed_permissions_list, + requested_permissions_set, + ) + return {perm.split("_")[1].upper() for perm in allowed_permissions_list if perm.startswith(vaccine_type)} + + +@file_level_validation_logging_decorator +def file_level_validation(incoming_message_body: dict) -> dict: + """Validates that the csv headers are correct and that the supplier has permission to perform at least one of + the requested operations.
Returns an interim message body for row level processing.""" + try: + message_id = incoming_message_body.get("message_id") + vaccine: Vaccine = next( # Convert vaccine_type to Vaccine enum + vaccine for vaccine in Vaccine if vaccine.value == incoming_message_body.get("vaccine_type").upper() + ) + supplier = incoming_message_body.get("supplier").upper() + file_key = incoming_message_body.get("filename") + permission = incoming_message_body.get("permission") + created_at_formatted_string = incoming_message_body.get("created_at_formatted_string") + + # Fetch the data + csv_reader, csv_data = get_csv_content_dict_reader(file_key) + + try: + validate_content_headers(csv_reader) + + # Validate has permission to perform at least one of the requested actions + allowed_operations_set = validate_action_flag_permissions(supplier, vaccine.value, permission, csv_data) + except (InvalidHeaders, NoOperationPermissions): + make_and_upload_ack_file(message_id, file_key, False, False, created_at_formatted_string) + raise + + # Initialise the accumulated_ack_file_content with the headers + make_and_upload_ack_file(message_id, file_key, True, True, created_at_formatted_string) + + return { + "message_id": message_id, + "vaccine": vaccine, + "supplier": supplier, + "file_key": file_key, + "allowed_operations": allowed_operations_set, + "created_at_formatted_string": created_at_formatted_string, + "csv_dict_reader": csv_reader, + } + except Exception as error: + logger.error("Error in file_level_validation: %s", error) + # NOTE: The Exception may occur before the file_id, file_key and created_at_formatted_string are assigned + message_id = message_id or "Unable to ascertain message_id" + file_key = file_key or "Unable to ascertain file_key" + created_at_formatted_string = created_at_formatted_string or "Unable to ascertain created_at_formatted_string" + make_and_upload_ack_file(message_id, file_key, False, False, created_at_formatted_string) + raise diff --git a/recordprocessor/src/get_operation_permissions.py b/recordprocessor/src/get_operation_permissions.py deleted file mode 100644 index 57ad90a77..000000000 --- a/recordprocessor/src/get_operation_permissions.py +++ /dev/null @@ -1,12 +0,0 @@ -""""Functions for obtaining a dictionary of allowed action flags""" - -from mappings import Vaccine - - -def get_operation_permissions(vaccine: Vaccine, permission: str) -> set: - """Returns the set of allowed action flags.""" - return ( - {"CREATE", "UPDATE", "DELETE"} - if f"{vaccine.value}_FULL" in permission - else {perm.split("_")[1] for perm in permission if perm.startswith(vaccine.value)} - ) diff --git a/recordprocessor/src/logging_decorator.py b/recordprocessor/src/logging_decorator.py new file mode 100644 index 000000000..1a0ebd261 --- /dev/null +++ b/recordprocessor/src/logging_decorator.py @@ -0,0 +1,71 @@ +"""This module contains the logging decorator for sending the appropriate logs to Cloudwatch and Firehose.""" + +import json +import os +import time +from datetime import datetime +from functools import wraps +from clients import firehose_client, logger +from errors import NoOperationPermissions, InvalidHeaders + +STREAM_NAME = os.getenv("SPLUNK_FIREHOSE_NAME", "immunisation-fhir-api-internal-dev-splunk-firehose") + + +def send_log_to_firehose(log_data: dict) -> None: + """Sends the log_message to Firehose""" + try: + record = {"Data": json.dumps({"event": log_data}).encode("utf-8")} + response = firehose_client.put_record(DeliveryStreamName=STREAM_NAME, Record=record) + logger.info("Log sent to 
Firehose: %s", response) # TODO: Should we be logging full response? + except Exception as error: # pylint:disable = broad-exception-caught + logger.exception("Error sending log to Firehose: %s", error) + + +def generate_and_send_logs( + start_time, base_log_data: dict, additional_log_data: dict, is_error_log: bool = False +) -> None: + """Generates log data which includes the base_log_data, additional_log_data, and time taken (calculated using the + current time and given start_time) and sends them to Cloudwatch and Firehose.""" + log_data = {**base_log_data, "time_taken": f"{round(time.time() - start_time, 5)}s", **additional_log_data} + log_function = logger.error if is_error_log else logger.info + log_function(json.dumps(log_data)) + send_log_to_firehose(log_data) + + +def file_level_validation_logging_decorator(func): + """ + Sends the appropriate logs to Cloudwatch and Firehose based on the result of the file_level_validation + function call. + """ + + @wraps(func) + def wrapper(*args, **kwargs): + incoming_message_body = kwargs.get("incoming_message_body") or args[0] + base_log_data = { + "function_name": f"record_processor_{func.__name__}", + "date_time": str(datetime.now()), + "file_key": incoming_message_body.get("filename"), + "message_id": incoming_message_body.get("message_id"), + "vaccine_type": incoming_message_body.get("vaccine_type"), + "supplier": incoming_message_body.get("supplier"), + } + start_time = time.time() + + try: + result = func(*args, **kwargs) + additional_log_data = {"statusCode": 200, "message": "Successfully sent for record processing"} + generate_and_send_logs(start_time, base_log_data, additional_log_data=additional_log_data) + return result + + except (InvalidHeaders, NoOperationPermissions, Exception) as e: + message = ( + str(e) if (isinstance(e, InvalidHeaders) or isinstance(e, NoOperationPermissions)) else "Server error" + ) + status_code = ( + 400 if isinstance(e, InvalidHeaders) else 403 if isinstance(e, NoOperationPermissions) else 500 + ) + additional_log_data = {"statusCode": status_code, "message": message, "error": str(e)} + generate_and_send_logs(start_time, base_log_data, additional_log_data, is_error_log=True) + raise + + return wrapper diff --git a/recordprocessor/src/update_ack_file.py b/recordprocessor/src/update_ack_file.py deleted file mode 100644 index ccdd6ae51..000000000 --- a/recordprocessor/src/update_ack_file.py +++ /dev/null @@ -1,64 +0,0 @@ -"""Functions for adding a row of data to the ack file""" - -from io import StringIO, BytesIO -import os -from typing import Union -from clients import s3_client -from utils_for_recordprocessor import get_environment - - -def create_ack_data( - created_at_formatted_string: str, - row_id: str, - delivered: bool, - diagnostics: Union[None, str], - imms_id: Union[None, str], -) -> dict: - """Returns a dictionary containing the ack headers as keys, along with the relevant values.""" - return { - "MESSAGE_HEADER_ID": row_id, - "HEADER_RESPONSE_CODE": "OK" if (delivered and not diagnostics) else "Fatal Error", - "ISSUE_SEVERITY": "Information" if not diagnostics else "Fatal", - "ISSUE_CODE": "OK" if not diagnostics else "Fatal Error", - "ISSUE_DETAILS_CODE": "30001" if not diagnostics else "30002", - "RESPONSE_TYPE": "Business", - "RESPONSE_CODE": "30001" if (delivered and not diagnostics) else "30002", - "RESPONSE_DISPLAY": ( - "Success" if (delivered and not diagnostics) else "Business Level Response Value - Processing Error" - ), - "RECEIVED_TIME": created_at_formatted_string, - 
"MAILBOX_FROM": "", # TODO: Leave blank for DPS, use mailbox name if picked up from MESH mail box - "LOCAL_ID": "", # TODO: Leave blank for DPS, obtain from ctl file if picked up from MESH mail box - "IMMS_ID": imms_id or "", - "OPERATION_OUTCOME": diagnostics or "", - "MESSAGE_DELIVERY": delivered, - } - - -def add_row_to_ack_file(ack_data: dict, accumulated_ack_file_content: StringIO, file_key: str) -> StringIO: - """Adds the data row to the uploaded ack file""" - data_row_str = [str(item) for item in ack_data.values()] - cleaned_row = "|".join(data_row_str).replace(" |", "|").replace("| ", "|").strip() - accumulated_ack_file_content.write(cleaned_row + "\n") - csv_file_like_object = BytesIO(accumulated_ack_file_content.getvalue().encode("utf-8")) - ack_bucket_name = os.getenv("ACK_BUCKET_NAME", f"immunisation-batch-{get_environment()}-data-destinations") - ack_filename = f"processedFile/{file_key.replace('.csv', '_response.csv')}" - s3_client.upload_fileobj(csv_file_like_object, ack_bucket_name, ack_filename) - return accumulated_ack_file_content - - -def update_ack_file( - file_key: str, - bucket_name: str, - accumulated_ack_file_content: StringIO, - row_id: str, - message_delivered: bool, - diagnostics: Union[None, str], - imms_id: Union[None, str], -) -> StringIO: - """Updates the ack file with the new data row based on the given arguments""" - response = s3_client.head_object(Bucket=bucket_name, Key=file_key) - created_at_formatted_string = response["LastModified"].strftime("%Y%m%dT%H%M%S00") - ack_data_row = create_ack_data(created_at_formatted_string, row_id, message_delivered, diagnostics, imms_id) - accumulated_ack_file_content = add_row_to_ack_file(ack_data_row, accumulated_ack_file_content, file_key) - return accumulated_ack_file_content diff --git a/recordprocessor/src/utils_for_recordprocessor.py b/recordprocessor/src/utils_for_recordprocessor.py index c246c1b9f..6eb4401cf 100644 --- a/recordprocessor/src/utils_for_recordprocessor.py +++ b/recordprocessor/src/utils_for_recordprocessor.py @@ -13,8 +13,8 @@ def get_environment() -> str: return _env if _env in ["internal-dev", "int", "ref", "sandbox", "prod"] else "internal-dev" -def get_csv_content_dict_reader(bucket_name: str, file_key: str) -> DictReader: - """Returns the requested file contents in the form of a DictReader""" - response = s3_client.get_object(Bucket=bucket_name, Key=file_key) +def get_csv_content_dict_reader(file_key: str) -> DictReader: + """Returns the requested file contents from the source bucket in the form of a DictReader""" + response = s3_client.get_object(Bucket=os.getenv("SOURCE_BUCKET_NAME"), Key=file_key) csv_data = response["Body"].read().decode("utf-8") return DictReader(StringIO(csv_data), delimiter="|"), csv_data diff --git a/recordprocessor/tests/test_initial_file_validation.py b/recordprocessor/tests/test_file_level_validation.py similarity index 60% rename from recordprocessor/tests/test_initial_file_validation.py rename to recordprocessor/tests/test_file_level_validation.py index aa1bd345e..22e916768 100644 --- a/recordprocessor/tests/test_initial_file_validation.py +++ b/recordprocessor/tests/test_file_level_validation.py @@ -1,11 +1,12 @@ -"""Tests for initial file validation functions""" +"""Tests for file level validation functions""" import unittest from unittest.mock import patch # If mock_s3 is not imported here then tests in other files fail. It is not clear why this is. 
from moto import mock_s3 # noqa: F401 -from batch_processing import validate_content_headers, validate_action_flag_permissions +from file_level_validation import validate_content_headers, validate_action_flag_permissions +from errors import NoOperationPermissions, InvalidHeaders from tests.utils_for_recordprocessor_tests.utils_for_recordprocessor_tests import convert_string_to_dict_reader from tests.utils_for_recordprocessor_tests.values_for_recordprocessor_tests import ( MOCK_ENVIRONMENT_DICT, @@ -17,27 +18,30 @@ @patch.dict("os.environ", MOCK_ENVIRONMENT_DICT) -class TestInitialFileValidation(unittest.TestCase): - """Tests for the initial file validation functions""" +class TestFileLevelValidation(unittest.TestCase): + """Tests for the file level validation functions""" def test_validate_content_headers(self): "Tests that validate_content_headers returns True for an exact header match and False otherwise" - # Test case tuples are stuctured as (file_content, expected_result) - test_cases = [ - (ValidMockFileContent.with_new_and_update, True), # Valid file content - (ValidMockFileContent.with_new_and_update.replace("SITE_CODE", "SITE_COVE"), False), # Misspelled header - (ValidMockFileContent.with_new_and_update.replace("SITE_CODE|", ""), False), # Missing header - ( - ValidMockFileContent.with_new_and_update.replace("PERSON_DOB|", "PERSON_DOB|EXTRA_HEADER|"), - False, - ), # Extra header + + # Case: Valid file content + # validate_content_headers takes a csv dict reader as its input + test_data = convert_string_to_dict_reader(ValidMockFileContent.with_new_and_update) + self.assertIsNone(validate_content_headers(test_data)) + + # Case: Invalid file content + invalid_file_contents = [ + ValidMockFileContent.with_new_and_update.replace("SITE_CODE", "SITE_COVE"), # Misspelled header + ValidMockFileContent.with_new_and_update.replace("SITE_CODE|", ""), # Missing header + ValidMockFileContent.with_new_and_update.replace("PERSON_DOB|", "PERSON_DOB|EXTRA_HEADER|"), # Extra header ] - for file_content, expected_result in test_cases: + for invalid_file_content in invalid_file_contents: with self.subTest(): # validate_content_headers takes a csv dict reader as it's input - test_data = convert_string_to_dict_reader(file_content) - self.assertEqual(validate_content_headers(test_data), expected_result) + test_data = convert_string_to_dict_reader(invalid_file_content) + with self.assertRaises(InvalidHeaders): + validate_content_headers(test_data) def test_validate_action_flag_permissions(self): """ @@ -54,46 +58,56 @@ def test_validate_action_flag_permissions(self): "update", "UPDATE" ) - # Test case tuples are stuctured as (vaccine_type, vaccine_permissions, file_content, expected_result) + # Case: Supplier has permissions to perform at least one of the requested operations + # Test case tuples are structured as (vaccine_type, vaccine_permissions, file_content, expected_output) test_cases = [ # FLU, full permissions, lowercase action flags - ("FLU", ["FLU_FULL"], valid_content_new_and_update_lowercase, True), + ("FLU", ["FLU_FULL"], valid_content_new_and_update_lowercase, {"CREATE", "UPDATE", "DELETE"}), # FLU, partial permissions, uppercase action flags - ("FLU", ["FLU_CREATE"], valid_content_new_and_update_uppercase, True), + ("FLU", ["FLU_CREATE"], valid_content_new_and_update_uppercase, {"CREATE"}), # FLU, full permissions, mixed case action flags - ("FLU", ["FLU_FULL"], valid_content_new_and_update_mixedcase, True), + ("FLU", ["FLU_FULL"], valid_content_new_and_update_mixedcase, {"CREATE",
"UPDATE", "DELETE"}), # FLU, partial permissions (create) - ("FLU", ["FLU_DELETE", "FLU_CREATE"], valid_content_new_and_update_lowercase, True), + ("FLU", ["FLU_DELETE", "FLU_CREATE"], valid_content_new_and_update_lowercase, {"CREATE", "DELETE"}), # FLU, partial permissions (update) - ("FLU", ["FLU_UPDATE"], valid_content_new_and_update_lowercase, True), + ("FLU", ["FLU_UPDATE"], valid_content_new_and_update_lowercase, {"UPDATE"}), # FLU, partial permissions (delete) - ("FLU", ["FLU_DELETE"], valid_content_new_and_delete_lowercase, True), - # FLU, no permissions - ("FLU", ["FLU_UPDATE", "COVID19_FULL"], valid_content_new_and_delete_lowercase, False), + ("FLU", ["FLU_DELETE"], valid_content_new_and_delete_lowercase, {"DELETE"}), # COVID19, full permissions - ("COVID19", ["COVID19_FULL"], valid_content_new_and_delete_lowercase, True), + ("COVID19", ["COVID19_FULL"], valid_content_new_and_delete_lowercase, {"CREATE", "UPDATE", "DELETE"}), # COVID19, partial permissions - ("COVID19", ["COVID19_UPDATE"], valid_content_update_and_delete_lowercase, True), - # COVID19, no permissions - ("COVID19", ["FLU_CREATE", "FLU_UPDATE"], valid_content_update_and_delete_lowercase, False), + ("COVID19", ["COVID19_UPDATE"], valid_content_update_and_delete_lowercase, {"UPDATE"}), # RSV, full permissions - ("RSV", ["RSV_FULL"], valid_content_new_and_delete_lowercase, True), + ("RSV", ["RSV_FULL"], valid_content_new_and_delete_lowercase, {"CREATE", "UPDATE", "DELETE"}), # RSV, partial permissions - ("RSV", ["RSV_UPDATE"], valid_content_update_and_delete_lowercase, True), - # RSV, no permissions - ("RSV", ["FLU_CREATE", "FLU_UPDATE"], valid_content_update_and_delete_lowercase, False), + ("RSV", ["RSV_UPDATE"], valid_content_update_and_delete_lowercase, {"UPDATE"}), # RSV, full permissions, mixed case action flags - ("RSV", ["RSV_FULL"], valid_content_new_and_update_mixedcase, True), + ("RSV", ["RSV_FULL"], valid_content_new_and_update_mixedcase, {"CREATE", "UPDATE", "DELETE"}), ] - for vaccine_type, vaccine_permissions, file_content, expected_result in test_cases: - with self.subTest(): - # validate_action_flag_permissions takes a csv dict reader as one of it's args + for vaccine_type, vaccine_permissions, file_content, expected_output in test_cases: + with self.subTest(f"Vaccine_type {vaccine_type} - permissions {vaccine_permissions}"): self.assertEqual( validate_action_flag_permissions("TEST_SUPPLIER", vaccine_type, vaccine_permissions, file_content), - expected_result, + expected_output, ) + # Case: Supplier has no permissions to perform any of the requested operations + # Test case tuples are structured as (vaccine_type, vaccine_permissions, file_content) + test_cases = [ + # FLU, no permissions + ("FLU", ["FLU_UPDATE", "COVID19_FULL"], valid_content_new_and_delete_lowercase), + # COVID19, no permissions + ("COVID19", ["FLU_CREATE", "FLU_UPDATE"], valid_content_update_and_delete_lowercase), + # RSV, no permissions + ("RSV", ["FLU_CREATE", "FLU_UPDATE"], valid_content_update_and_delete_lowercase), + ] + + for vaccine_type, vaccine_permissions, file_content in test_cases: + with self.subTest(): + with self.assertRaises(NoOperationPermissions): + validate_action_flag_permissions("TEST_SUPPLIER", vaccine_type, vaccine_permissions, file_content) + if __name__ == "__main__": unittest.main() diff --git a/recordprocessor/tests/test_logging_decorator.py b/recordprocessor/tests/test_logging_decorator.py new file mode 100644 index 000000000..fa4bf3484 --- /dev/null +++ b/recordprocessor/tests/test_logging_decorator.py @@
-0,0 +1,232 @@ +"""Tests for the logging_decorator and its helper functions""" + +import unittest +from unittest.mock import patch +from datetime import datetime +import json +from copy import deepcopy +from boto3 import client as boto3_client +from moto import mock_s3, mock_firehose +from file_level_validation import file_level_validation +from logging_decorator import send_log_to_firehose, generate_and_send_logs +from clients import REGION_NAME +from errors import InvalidHeaders, NoOperationPermissions +from tests.utils_for_recordprocessor_tests.values_for_recordprocessor_tests import ( + MOCK_ENVIRONMENT_DICT, + MockFileDetails, + BucketNames, + ValidMockFileContent, + Firehose, +) +from tests.utils_for_recordprocessor_tests.utils_for_recordprocessor_tests import GenericSetUp, GenericTearDown + +s3_client = boto3_client("s3", region_name=REGION_NAME) +firehose_client = boto3_client("firehose", region_name=REGION_NAME) +MOCK_FILE_DETAILS = MockFileDetails.flu_emis +COMMON_LOG_DATA = { + "function_name": "record_processor_file_level_validation", + "date_time": "2024-01-01 12:00:00", # (tests mock a 2024-01-01 12:00:00 datetime) + "time_taken": "0.12346s", # Time taken is rounded to 5 decimal places (tests mock a 0.123456s time taken) + "file_key": MOCK_FILE_DETAILS.file_key, + "message_id": MOCK_FILE_DETAILS.message_id, + "vaccine_type": MOCK_FILE_DETAILS.vaccine_type, + "supplier": MOCK_FILE_DETAILS.supplier, +} + + +@mock_s3 +@mock_firehose +@patch.dict("os.environ", MOCK_ENVIRONMENT_DICT) +class TestLoggingDecorator(unittest.TestCase): + """Tests for the logging_decorator and its helper functions""" + + def setUp(self): + """Set up the mock S3 buckets and the Firehose delivery stream""" + GenericSetUp(s3_client, firehose_client) + + def tearDown(self): + GenericTearDown(s3_client, firehose_client) + + def test_send_log_to_firehose(self): + """ + Tests that the send_log_to_firehose function calls firehose_client.put_record with the correct arguments. + NOTE: mock_firehose does not persist the data, so at this level it is only possible to test what the call args + were, not that the data reached the destination.
+ """ + log_data = {"test_key": "test_value"} + + with patch("logging_decorator.firehose_client") as mock_firehose_client: + send_log_to_firehose(log_data) + + expected_firehose_record = {"Data": json.dumps({"event": log_data}).encode("utf-8")} + mock_firehose_client.put_record.assert_called_once_with( + DeliveryStreamName=Firehose.STREAM_NAME, Record=expected_firehose_record + ) + + def test_generate_and_send_logs(self): + """ + Tests that the generate_and_send_logs function logs the correct data at the correct level for cloudwatch + and calls send_log_to_firehose with the correct log data + """ + base_log_data = {"base_key": "base_value"} + additional_log_data = {"additional_key": "additional_value"} + start_time = 1672531200 + + # CASE: Successful log - is_error_log arg set to False + with ( # noqa: E999 + patch("logging_decorator.logger") as mock_logger, # noqa: E999 + patch("logging_decorator.send_log_to_firehose") as mock_send_log_to_firehose, # noqa: E999 + patch("logging_decorator.time") as mock_time, # noqa: E999 + ): # noqa: E999 + mock_time.time.return_value = 1672531200.123456 # Mocks the end time to be 0.123456s after the start time + generate_and_send_logs(start_time, base_log_data, additional_log_data, is_error_log=False) + + expected_log_data = {"base_key": "base_value", "time_taken": "0.12346s", "additional_key": "additional_value"} + log_data = json.loads(mock_logger.info.call_args[0][0]) + self.assertEqual(log_data, expected_log_data) + mock_send_log_to_firehose.assert_called_once_with(expected_log_data) + + # CASE: Error log - is_error_log arg set to True + with ( # noqa: E999 + patch("logging_decorator.logger") as mock_logger, # noqa: E999 + patch("logging_decorator.send_log_to_firehose") as mock_send_log_to_firehose, # noqa: E999 + patch("logging_decorator.time") as mock_time, # noqa: E999 + ): # noqa: E999 + mock_time.time.return_value = 1672531200.123456 # Mocks the end time to be 0.123456s after the start time + generate_and_send_logs(start_time, base_log_data, additional_log_data, is_error_log=True) + + expected_log_data = {"base_key": "base_value", "time_taken": "0.12346s", "additional_key": "additional_value"} + log_data = json.loads(mock_logger.error.call_args[0][0]) + self.assertEqual(log_data, expected_log_data) + mock_send_log_to_firehose.assert_called_once_with(expected_log_data) + + def test_splunk_logger_successful_validation(self): + """Tests the splunk logger is called when file-level validation is successful""" + + s3_client.put_object( + Bucket=BucketNames.SOURCE, + Key=MOCK_FILE_DETAILS.file_key, + Body=ValidMockFileContent.with_new_and_update_and_delete, + ) + + with ( # noqa: E999 + patch("logging_decorator.datetime") as mock_datetime, # noqa: E999 + patch("logging_decorator.time") as mock_time, # noqa: E999 + patch("logging_decorator.logger") as mock_logger, # noqa: E999 + patch("logging_decorator.firehose_client") as mock_firehose_client, # noqa: E999 + ): # noqa: E999 + mock_time.time.side_effect = [1672531200, 1672531200.123456] + mock_datetime.now.return_value = datetime(2024, 1, 1, 12, 0, 0) + file_level_validation(deepcopy(MOCK_FILE_DETAILS.event_full_permissions_dict)) + + expected_message = "Successfully sent for record processing" + expected_log_data = {**COMMON_LOG_DATA, "statusCode": 200, "message": expected_message} + + # Log data is the first positional argument of the first call to logger.info + log_data = json.loads(mock_logger.info.call_args_list[0][0][0]) + self.assertEqual(log_data, expected_log_data) + + expected_firehose_record 
= {"Data": json.dumps({"event": log_data}).encode("utf-8")} + mock_firehose_client.put_record.assert_called_once_with( + DeliveryStreamName=Firehose.STREAM_NAME, Record=expected_firehose_record + ) + + def test_splunk_logger_handled_failure(self): + """Tests the splunk logger is called when file-level validation fails for a known reason""" + + # Test case tuples are structured as (file_content, event_dict, expected_error_type, + # expected_status_code, expected_error_message) + test_cases = [ + # CASE: Invalid headers + ( + ValidMockFileContent.with_new_and_update_and_delete.replace("NHS_NUMBER", "NHS_NUMBERS"), + MOCK_FILE_DETAILS.event_full_permissions_dict, + InvalidHeaders, + 400, + "File headers are invalid.", + ), + # CASE: No operation permissions + ( + ValidMockFileContent.with_new_and_update, + MOCK_FILE_DETAILS.event_delete_permissions_only_dict, # No permission for NEW or UPDATE + NoOperationPermissions, + 403, + f"{MOCK_FILE_DETAILS.supplier} does not have permissions to perform any of the requested actions.", + ), + ] + + for ( + mock_file_content, + event_dict, + expected_error_type, + expected_status_code, + expected_error_message, + ) in test_cases: + with self.subTest(expected_error_message): + + s3_client.put_object(Bucket=BucketNames.SOURCE, Key=MOCK_FILE_DETAILS.file_key, Body=mock_file_content) + + with ( # noqa: E999 + patch("logging_decorator.datetime") as mock_datetime, # noqa: E999 + patch("logging_decorator.time") as mock_time, # noqa: E999 + patch("logging_decorator.logger") as mock_logger, # noqa: E999 + patch("logging_decorator.firehose_client") as mock_firehose_client, # noqa: E999 + ): # noqa: E999 + mock_datetime.now.return_value = datetime(2024, 1, 1, 12, 0, 0) + mock_time.time.side_effect = [1672531200, 1672531200.123456] + with self.assertRaises(expected_error_type): + file_level_validation(deepcopy(event_dict)) + + expected_log_data = { + **COMMON_LOG_DATA, + "statusCode": expected_status_code, + "message": expected_error_message, + "error": expected_error_message, + } + + # Log data is the first positional argument of the first call to logger.error + log_data = json.loads(mock_logger.error.call_args_list[0][0][0]) + self.assertEqual(log_data, expected_log_data) + + expected_firehose_record = {"Data": json.dumps({"event": log_data}).encode("utf-8")} + mock_firehose_client.put_record.assert_called_once_with( + DeliveryStreamName=Firehose.STREAM_NAME, Record=expected_firehose_record + ) + + def test_splunk_logger_unhandled_failure(self): + """Tests the splunk logger is called when file-level validation fails for an unknown reason""" + s3_client.put_object( + Bucket=BucketNames.SOURCE, + Key=MOCK_FILE_DETAILS.file_key, + Body=ValidMockFileContent.with_new_and_update_and_delete, + ) + + with ( # noqa: E999 + patch("logging_decorator.datetime") as mock_datetime, # noqa: E999 + patch("logging_decorator.time") as mock_time, # noqa: E999 + patch("logging_decorator.logger") as mock_logger, # noqa: E999 + patch("logging_decorator.firehose_client") as mock_firehose_client, # noqa: E999 + patch( + "file_level_validation.validate_content_headers", side_effect=Exception("Test exception") + ), # noqa: E999 + ): # noqa: E999 + mock_time.time.side_effect = [1672531200, 1672531200.123456] + mock_datetime.now.return_value = datetime(2024, 1, 1, 12, 0, 0) + with self.assertRaises(Exception): + file_level_validation(deepcopy(MOCK_FILE_DETAILS.event_full_permissions_dict)) + + expected_log_data = { + **COMMON_LOG_DATA, + "statusCode": 500, + "message": "Server error", + 
"error": "Test exception", + } + + # Log data is the first positional argument of the first call to logger.error + log_data = json.loads(mock_logger.error.call_args_list[0][0][0]) + self.assertEqual(log_data, expected_log_data) + + expected_firehose_record = {"Data": json.dumps({"event": log_data}).encode("utf-8")} + mock_firehose_client.put_record.assert_called_once_with( + DeliveryStreamName=Firehose.STREAM_NAME, Record=expected_firehose_record + ) diff --git a/recordprocessor/tests/test_update_ack_file.py b/recordprocessor/tests/test_make_and_upload_ack_file.py similarity index 95% rename from recordprocessor/tests/test_update_ack_file.py rename to recordprocessor/tests/test_make_and_upload_ack_file.py index 9d9f7c430..f782fa7f1 100644 --- a/recordprocessor/tests/test_update_ack_file.py +++ b/recordprocessor/tests/test_make_and_upload_ack_file.py @@ -1,11 +1,11 @@ -"""Tests for update_ack_file.py""" +"""Tests for make_and_upload_ack_file functions""" import unittest from make_and_upload_ack_file import make_ack_data -class TestUpdateAckFile(unittest.TestCase): - "Tests for update_ack_file.py" +class TestMakeAndUploadAckFile(unittest.TestCase): + "Tests for make_and_upload_ack_file functions" def setUp(self) -> None: self.message_id = "test_id" diff --git a/recordprocessor/tests/test_process_csv_to_fhir.py b/recordprocessor/tests/test_process_csv_to_fhir.py index 107a18b83..e6aab2229 100644 --- a/recordprocessor/tests/test_process_csv_to_fhir.py +++ b/recordprocessor/tests/test_process_csv_to_fhir.py @@ -4,8 +4,12 @@ from unittest.mock import patch import boto3 from copy import deepcopy -from moto import mock_s3 +from moto import mock_s3, mock_firehose from batch_processing import process_csv_to_fhir +from tests.utils_for_recordprocessor_tests.utils_for_recordprocessor_tests import ( + GenericSetUp, + GenericTearDown, +) from tests.utils_for_recordprocessor_tests.values_for_recordprocessor_tests import ( MOCK_ENVIRONMENT_DICT, MockFileDetails, @@ -15,23 +19,21 @@ ) s3_client = boto3.client("s3", region_name=REGION_NAME) +firehose_client = boto3.client("firehose", region_name=REGION_NAME) test_file = MockFileDetails.rsv_emis @patch.dict("os.environ", MOCK_ENVIRONMENT_DICT) @mock_s3 -class TestProcessLambdaFunction(unittest.TestCase): +@mock_firehose +class TestProcessCsvToFhir(unittest.TestCase): """Tests for process_csv_to_fhir function""" def setUp(self) -> None: - for bucket_name in [BucketNames.SOURCE, BucketNames.DESTINATION]: - s3_client.create_bucket(Bucket=bucket_name, CreateBucketConfiguration={"LocationConstraint": REGION_NAME}) + GenericSetUp(s3_client, firehose_client) def tearDown(self) -> None: - for bucket_name in [BucketNames.SOURCE, BucketNames.DESTINATION]: - for obj in s3_client.list_objects_v2(Bucket=bucket_name).get("Contents", []): - s3_client.delete_object(Bucket=bucket_name, Key=obj["Key"]) - s3_client.delete_bucket(Bucket=bucket_name) + GenericTearDown(s3_client, firehose_client) @staticmethod def upload_source_file(file_key, file_content): diff --git a/recordprocessor/tests/test_lambda_e2e.py b/recordprocessor/tests/test_recordprocessor_main.py similarity index 84% rename from recordprocessor/tests/test_lambda_e2e.py rename to recordprocessor/tests/test_recordprocessor_main.py index 69274488d..22a53d0ee 100644 --- a/recordprocessor/tests/test_lambda_e2e.py +++ b/recordprocessor/tests/test_recordprocessor_main.py @@ -1,14 +1,18 @@ -"E2e tests for recordprocessor" +"Tests for main function for RecordProcessor" import unittest import json from decimal import Decimal 
from unittest.mock import patch from datetime import datetime, timedelta, timezone -from moto import mock_s3, mock_kinesis +from moto import mock_s3, mock_kinesis, mock_firehose from boto3 import client as boto3_client from batch_processing import main from constants import Diagnostics +from tests.utils_for_recordprocessor_tests.utils_for_recordprocessor_tests import ( + GenericSetUp, + GenericTearDown, +) from tests.utils_for_recordprocessor_tests.values_for_recordprocessor_tests import ( Kinesis, MOCK_ENVIRONMENT_DICT, @@ -25,6 +29,7 @@ s3_client = boto3_client("s3", region_name=REGION_NAME) kinesis_client = boto3_client("kinesis", region_name=REGION_NAME) +firehose_client = boto3_client("firehose", region_name=REGION_NAME) yesterday = datetime.now(timezone.utc) - timedelta(days=1) mock_rsv_emis_file = MockFileDetails.rsv_emis @@ -32,30 +37,18 @@ @patch.dict("os.environ", MOCK_ENVIRONMENT_DICT) @mock_s3 @mock_kinesis +@mock_firehose class TestRecordProcessor(unittest.TestCase): - """E2e tests for RecordProcessor""" + """Tests for main function for RecordProcessor""" def setUp(self) -> None: - for bucket_name in [BucketNames.SOURCE, BucketNames.DESTINATION]: - s3_client.create_bucket(Bucket=bucket_name, CreateBucketConfiguration={"LocationConstraint": REGION_NAME}) - - kinesis_client.create_stream(StreamName=Kinesis.STREAM_NAME, ShardCount=1) + GenericSetUp(s3_client, firehose_client, kinesis_client) def tearDown(self) -> None: - # Delete all of the buckets (the contents of each bucket must be deleted first) - for bucket_name in [BucketNames.SOURCE, BucketNames.DESTINATION]: - for obj in s3_client.list_objects_v2(Bucket=bucket_name).get("Contents", []): - s3_client.delete_object(Bucket=bucket_name, Key=obj["Key"]) - s3_client.delete_bucket(Bucket=bucket_name) - - # Delete the kinesis stream - try: - kinesis_client.delete_stream(StreamName=Kinesis.STREAM_NAME, EnforceConsumerDeletion=True) - except kinesis_client.exceptions.ResourceNotFoundException: - pass + GenericTearDown(s3_client, firehose_client, kinesis_client) @staticmethod - def upload_files(source_file_content): # pylint: disable=dangerous-default-value + def upload_source_files(source_file_content): # pylint: disable=dangerous-default-value """Uploads a test file with the TEST_FILE_KEY (RSV EMIS file) the given file content to the source bucket""" s3_client.put_object(Bucket=BucketNames.SOURCE, Key=mock_rsv_emis_file.file_key, Body=source_file_content) @@ -146,7 +139,7 @@ def test_e2e_full_permissions(self): Tests that file containing CREATE, UPDATE and DELETE is successfully processed when the supplier has full permissions. """ - self.upload_files(ValidMockFileContent.with_new_and_update_and_delete) + self.upload_source_files(ValidMockFileContent.with_new_and_update_and_delete) main(mock_rsv_emis_file.event_full_permissions) @@ -180,7 +173,7 @@ def test_e2e_partial_permissions(self): Tests that file containing CREATE, UPDATE and DELETE is successfully processed when the supplier only has CREATE permissions. """ - self.upload_files(ValidMockFileContent.with_new_and_update_and_delete) + self.upload_source_files(ValidMockFileContent.with_new_and_update_and_delete) main(mock_rsv_emis_file.event_create_permissions_only) @@ -222,7 +215,7 @@ def test_e2e_no_permissions(self): Tests that file containing UPDATE and DELETE is successfully processed when the supplier has CREATE permissions only. 
""" - self.upload_files(ValidMockFileContent.with_update_and_delete) + self.upload_source_files(ValidMockFileContent.with_update_and_delete) main(mock_rsv_emis_file.event_create_permissions_only) @@ -233,7 +226,7 @@ def test_e2e_no_permissions(self): def test_e2e_invalid_action_flags(self): """Tests that file is successfully processed when the ACTION_FLAG field is empty or invalid.""" - self.upload_files( + self.upload_source_files( ValidMockFileContent.with_update_and_delete.replace("update", "").replace("delete", "INVALID") ) @@ -262,7 +255,7 @@ def test_e2e_differing_amounts_of_data(self): mandatory_fields_only_values = "|".join(f'"{v}"' for v in MockFieldDictionaries.mandatory_fields_only.values()) critical_fields_only_values = "|".join(f'"{v}"' for v in MockFieldDictionaries.critical_fields_only.values()) file_content = f"{headers}\n{all_fields_values}\n{mandatory_fields_only_values}\n{critical_fields_only_values}" - self.upload_files(file_content) + self.upload_source_files(file_content) main(mock_rsv_emis_file.event_full_permissions) @@ -298,14 +291,34 @@ def test_e2e_kinesis_failed(self): Tests that, for a file with valid content and supplier with full permissions, when the kinesis send fails, the ack file is created and documents an error. """ - self.upload_files(ValidMockFileContent.with_new_and_update) + self.upload_source_files(ValidMockFileContent.with_new_and_update) # Delete the kinesis stream, to cause kinesis send to fail kinesis_client.delete_stream(StreamName=Kinesis.STREAM_NAME, EnforceConsumerDeletion=True) - main(mock_rsv_emis_file.event_full_permissions) - + with ( # noqa: E999 + patch("logging_decorator.send_log_to_firehose") as mock_send_log_to_firehose, # noqa: E999 + patch("logging_decorator.datetime") as mock_datetime, # noqa: E999 + patch("logging_decorator.time") as mock_time, # noqa: E999 + ): # noqa: E999 + mock_time.time.side_effect = [1672531200, 1672531200.123456] + mock_datetime.now.return_value = datetime(2024, 1, 1, 12, 0, 0) + main(mock_rsv_emis_file.event_full_permissions) + + # Since the failure occured at row level, not file level, the ack file should still be created + # and firehose logs should indicate a successful file level validation self.make_inf_ack_assertions(file_details=mock_rsv_emis_file, passed_validation=True) - # TODO: Make assertions r.e. 
logs (there is no output as kinesis failed) + expected_log_data = { + "function_name": "record_processor_file_level_validation", + "date_time": "2024-01-01 12:00:00", + "file_key": "RSV_Vaccinations_v5_8HK48_20210730T12000000.csv", + "message_id": "rsv_emis_test_id", + "vaccine_type": "RSV", + "supplier": "EMIS", + "time_taken": "0.12346s", + "statusCode": 200, + "message": "Successfully sent for record processing", + } + mock_send_log_to_firehose.assert_called_with(expected_log_data) if __name__ == "__main__": diff --git a/recordprocessor/tests/test_utils_for_recordprocessor.py b/recordprocessor/tests/test_utils_for_recordprocessor.py index 4c9259e47..e527dfa32 100644 --- a/recordprocessor/tests/test_utils_for_recordprocessor.py +++ b/recordprocessor/tests/test_utils_for_recordprocessor.py @@ -7,6 +7,7 @@ import boto3 from moto import mock_s3 from utils_for_recordprocessor import get_environment, get_csv_content_dict_reader +from tests.utils_for_recordprocessor_tests.utils_for_recordprocessor_tests import GenericSetUp, GenericTearDown from tests.utils_for_recordprocessor_tests.values_for_recordprocessor_tests import ( MOCK_ENVIRONMENT_DICT, MockFileDetails, @@ -25,14 +26,10 @@ class TestUtilsForRecordprocessor(unittest.TestCase): """Tests for utils_for_recordprocessor""" def setUp(self) -> None: - for bucket_name in [BucketNames.SOURCE, BucketNames.DESTINATION]: - s3_client.create_bucket(Bucket=bucket_name, CreateBucketConfiguration={"LocationConstraint": REGION_NAME}) + GenericSetUp(s3_client) def tearDown(self) -> None: - for bucket_name in [BucketNames.SOURCE, BucketNames.DESTINATION]: - for obj in s3_client.list_objects_v2(Bucket=bucket_name).get("Contents", []): - s3_client.delete_object(Bucket=bucket_name, Key=obj["Key"]) - s3_client.delete_bucket(Bucket=bucket_name) + GenericTearDown(s3_client) @staticmethod def upload_source_file(file_key, file_content): @@ -45,7 +42,7 @@ def test_get_csv_content_dict_reader(self): """Tests that get_csv_content_dict_reader returns the correct csv data""" self.upload_source_file(test_file.file_key, ValidMockFileContent.with_new_and_update) expected_output = csv.DictReader(StringIO(ValidMockFileContent.with_new_and_update), delimiter="|") - result, csv_data = get_csv_content_dict_reader(BucketNames.SOURCE, test_file.file_key) + result, csv_data = get_csv_content_dict_reader(test_file.file_key) self.assertEqual(list(result), list(expected_output)) self.assertEqual(csv_data, ValidMockFileContent.with_new_and_update) diff --git a/recordprocessor/tests/utils_for_recordprocessor_tests/utils_for_recordprocessor_tests.py b/recordprocessor/tests/utils_for_recordprocessor_tests/utils_for_recordprocessor_tests.py index 4dba3b25e..7f4c2d934 100644 --- a/recordprocessor/tests/utils_for_recordprocessor_tests/utils_for_recordprocessor_tests.py +++ b/recordprocessor/tests/utils_for_recordprocessor_tests/utils_for_recordprocessor_tests.py @@ -2,8 +2,67 @@ from csv import DictReader from io import StringIO +from tests.utils_for_recordprocessor_tests.values_for_recordprocessor_tests import ( + BucketNames, + REGION_NAME, + Firehose, + Kinesis, +) def convert_string_to_dict_reader(data_string: str): """Take a data string and convert it to a csv DictReader""" return DictReader(StringIO(data_string), delimiter="|") + + +class GenericSetUp: + """ + Performs generic setup of mock resources: + * If s3_client is provided, creates source, destination and firehose buckets (firehose bucket is used for testing + only) + * If firehose_client is provided, creates a firehose 
delivery stream + * If kinesis_client is provided, creates a kinesis stream + """ + + def __init__(self, s3_client=None, firehose_client=None, kinesis_client=None): + + if s3_client: + for bucket_name in [BucketNames.SOURCE, BucketNames.DESTINATION, BucketNames.MOCK_FIREHOSE]: + s3_client.create_bucket( + Bucket=bucket_name, CreateBucketConfiguration={"LocationConstraint": REGION_NAME} + ) + + if firehose_client: + firehose_client.create_delivery_stream( + DeliveryStreamName=Firehose.STREAM_NAME, + DeliveryStreamType="DirectPut", + S3DestinationConfiguration={ + "RoleARN": "arn:aws:iam::123456789012:role/mock-role", + "BucketARN": "arn:aws:s3:::" + BucketNames.MOCK_FIREHOSE, + "Prefix": "firehose-backup/", + }, + ) + + if kinesis_client: + kinesis_client.create_stream(StreamName=Kinesis.STREAM_NAME, ShardCount=1) + + +class GenericTearDown: + """Performs generic tear down of mock resources""" + + def __init__(self, s3_client=None, firehose_client=None, kinesis_client=None): + + if s3_client: + for bucket_name in [BucketNames.SOURCE, BucketNames.DESTINATION]: + for obj in s3_client.list_objects_v2(Bucket=bucket_name).get("Contents", []): + s3_client.delete_object(Bucket=bucket_name, Key=obj["Key"]) + s3_client.delete_bucket(Bucket=bucket_name) + + if firehose_client: + firehose_client.delete_delivery_stream(DeliveryStreamName=Firehose.STREAM_NAME) + + if kinesis_client: + try: + kinesis_client.delete_stream(StreamName=Kinesis.STREAM_NAME, EnforceConsumerDeletion=True) + except kinesis_client.exceptions.ResourceNotFoundException: + pass diff --git a/recordprocessor/tests/utils_for_recordprocessor_tests/values_for_recordprocessor_tests.py b/recordprocessor/tests/utils_for_recordprocessor_tests/values_for_recordprocessor_tests.py index d9b8f1ba7..1e7c7dc1a 100644 --- a/recordprocessor/tests/utils_for_recordprocessor_tests/values_for_recordprocessor_tests.py +++ b/recordprocessor/tests/utils_for_recordprocessor_tests/values_for_recordprocessor_tests.py @@ -40,6 +40,7 @@ class BucketNames: SOURCE = "immunisation-batch-internal-dev-data-sources" DESTINATION = "immunisation-batch-internal-dev-data-destinations" + MOCK_FIREHOSE = "mock-firehose-bucket" class Kinesis: @@ -48,6 +49,12 @@ class Kinesis: STREAM_NAME = "imms-batch-internal-dev-processingdata-stream" +class Firehose: + """Class containing Firehose values for use in tests""" + + STREAM_NAME = "immunisation-fhir-api-internal-dev-splunk-firehose" + + MOCK_ENVIRONMENT_DICT = { "ENVIRONMENT": "internal-dev", "LOCAL_ACCOUNT_ID": "123456789012", @@ -56,6 +63,7 @@ class Kinesis: "SHORT_QUEUE_PREFIX": "imms-batch-internal-dev", "KINESIS_STREAM_NAME": Kinesis.STREAM_NAME, "KINESIS_STREAM_ARN": f"arn:aws:kinesis:{REGION_NAME}:123456789012:stream/{Kinesis.STREAM_NAME}", + "FIREHOSE_STREAM_NAME": Firehose.STREAM_NAME, } diff --git a/terraform/ecs_batch_processor_config.tf b/terraform/ecs_batch_processor_config.tf index be306d1ae..92e1b829d 100644 --- a/terraform/ecs_batch_processor_config.tf +++ b/terraform/ecs_batch_processor_config.tf @@ -6,13 +6,13 @@ resource "aws_ecs_cluster" "ecs_cluster" { # Locals for Lambda processing paths and hash locals { processing_lambda_dir = abspath("${path.root}/../recordprocessor") - processing_path_include = ["**"] - processing_path_exclude = ["**/__pycache__/**"] - processing_files_include = setunion([for f in local.processing_path_include : fileset(local.processing_lambda_dir, f)]...) - processing_files_exclude = setunion([for f in local.processing_path_exclude : fileset(local.processing_lambda_dir, f)]...) 
+ processing_path_include = ["**"] + processing_path_exclude = ["**/__pycache__/**"] + processing_files_include = setunion([for f in local.processing_path_include : fileset(local.processing_lambda_dir, f)]...) + processing_files_exclude = setunion([for f in local.processing_path_exclude : fileset(local.processing_lambda_dir, f)]...) processing_lambda_files = sort(setsubtract(local.processing_files_include, local.processing_files_exclude)) processing_lambda_dir_sha = sha1(join("", [for f in local.processing_lambda_files : filesha1("${local.processing_lambda_dir}/${f}")])) - image_tag = "latest" + image_tag = "latest" } # Create ECR Repository for processing. @@ -27,10 +27,10 @@ resource "aws_ecr_repository" "processing_repository" { module "processing_docker_image" { source = "terraform-aws-modules/lambda/aws//modules/docker-build" - docker_file_path = "Dockerfile" - create_ecr_repo = false - ecr_repo = aws_ecr_repository.processing_repository.name - ecr_repo_lifecycle_policy = jsonencode({ + docker_file_path = "Dockerfile" + create_ecr_repo = false + ecr_repo = aws_ecr_repository.processing_repository.name + ecr_repo_lifecycle_policy = jsonencode({ "rules" : [ { "rulePriority" : 1, @@ -46,7 +46,7 @@ module "processing_docker_image" { } ] }) - + platform = "linux/amd64" use_image_tag = false source_path = local.processing_lambda_dir @@ -58,7 +58,7 @@ module "processing_docker_image" { # Define the IAM Role for ECS Task Execution resource "aws_iam_role" "ecs_task_exec_role" { name = "${local.short_prefix}-ecs-task-exec-role" - + assume_role_policy = jsonencode({ Version = "2012-10-17", Statement = [ @@ -74,19 +74,19 @@ resource "aws_iam_role" "ecs_task_exec_role" { } resource "aws_iam_role_policy_attachment" "task_execution_ecr_policy" { - role = aws_iam_role.ecs_task_exec_role.name - policy_arn = "arn:aws:iam::aws:policy/AmazonEC2ContainerRegistryReadOnly" + role = aws_iam_role.ecs_task_exec_role.name + policy_arn = "arn:aws:iam::aws:policy/AmazonEC2ContainerRegistryReadOnly" } # Define the IAM Role for ECS Task Execution with Kinesis Permissions resource "aws_iam_policy" "ecs_task_exec_policy" { - name = "${local.short_prefix}-ecs-task-exec-policy" + name = "${local.short_prefix}-ecs-task-exec-policy" policy = jsonencode({ Version = "2012-10-17", Statement = [ { - Effect = "Allow", - Action = [ + Effect = "Allow", + Action = [ "logs:CreateLogGroup", "logs:CreateLogStream", "logs:PutLogEvents" @@ -94,8 +94,8 @@ resource "aws_iam_policy" "ecs_task_exec_policy" { Resource = "arn:aws:logs:${var.aws_region}:${local.local_account_id}:log-group:/aws/vendedlogs/ecs/${local.short_prefix}-processor-task:*" }, { - Effect = "Allow", - Action = [ + Effect = "Allow", + Action = [ "s3:GetObject", "s3:ListBucket", "s3:PutObject" @@ -108,32 +108,40 @@ resource "aws_iam_policy" "ecs_task_exec_policy" { ] }, { - Effect = "Allow" + Effect = "Allow" Action = [ "kms:Encrypt", "kms:Decrypt", "kms:GenerateDataKey*" ] Resource = [ - data.aws_kms_key.existing_s3_encryption_key.arn, - data.aws_kms_key.existing_kinesis_encryption_key.arn + data.aws_kms_key.existing_s3_encryption_key.arn, + data.aws_kms_key.existing_kinesis_encryption_key.arn ] }, { - Effect = "Allow", - Action = [ + Effect = "Allow", + Action = [ "kinesis:PutRecord", "kinesis:PutRecords" ], Resource = local.kinesis_arn }, { - Effect = "Allow", - Action = [ + Effect = "Allow", + Action = [ "ecr:GetAuthorizationToken" ], Resource = "arn:aws:ecr:${var.aws_region}:${local.local_account_id}:repository/${local.short_prefix}-processing-repo" - } + }, + { + 
"Effect" : "Allow", + "Action" : [ + "firehose:PutRecord", + "firehose:PutRecordBatch" + ], + "Resource" : "arn:aws:firehose:*:*:deliverystream/${module.splunk.firehose_stream_name}" + } ] }) } @@ -144,7 +152,7 @@ resource "aws_iam_role_policy_attachment" "ecs_task_exec_policy_attachment" { } resource "aws_cloudwatch_log_group" "ecs_task_log_group" { - name = "/aws/vendedlogs/ecs/${local.short_prefix}-processor-task" + name = "/aws/vendedlogs/ecs/${local.short_prefix}-processor-task" } # Create the ECS Task Definition @@ -155,11 +163,11 @@ resource "aws_ecs_task_definition" "ecs_task" { cpu = "8192" memory = "24576" runtime_platform { - operating_system_family = "LINUX" - cpu_architecture = "X86_64" - } - task_role_arn = aws_iam_role.ecs_task_exec_role.arn - execution_role_arn = aws_iam_role.ecs_task_exec_role.arn + operating_system_family = "LINUX" + cpu_architecture = "X86_64" + } + task_role_arn = aws_iam_role.ecs_task_exec_role.arn + execution_role_arn = aws_iam_role.ecs_task_exec_role.arn container_definitions = jsonencode([{ name = "${local.short_prefix}-process-records-container" @@ -181,7 +189,11 @@ resource "aws_ecs_task_definition" "ecs_task" { { name = "KINESIS_STREAM_NAME" value = "${local.short_prefix}-processingdata-stream" - } + }, + { + name = "SPLUNK_FIREHOSE_NAME" + value = module.splunk.firehose_stream_name + } ] logConfiguration = { logDriver = "awslogs" @@ -197,28 +209,28 @@ resource "aws_ecs_task_definition" "ecs_task" { # IAM Role for EventBridge Pipe resource "aws_iam_role" "fifo_pipe_role" { -name = "${local.short_prefix}-eventbridge-pipe-role" -assume_role_policy = jsonencode({ - Version = "2012-10-17" - Statement = [ - { - Action = "sts:AssumeRole" - Effect = "Allow" - Principal = { - Service = "pipes.amazonaws.com" - } - } - ] -}) + name = "${local.short_prefix}-eventbridge-pipe-role" + assume_role_policy = jsonencode({ + Version = "2012-10-17" + Statement = [ + { + Action = "sts:AssumeRole" + Effect = "Allow" + Principal = { + Service = "pipes.amazonaws.com" + } + } + ] + }) } resource "aws_iam_policy" "fifo_pipe_policy" { - name = "${local.short_prefix}-fifo-pipe-policy" + name = "${local.short_prefix}-fifo-pipe-policy" policy = jsonencode({ Version = "2012-10-17", Statement = [ { - Effect = "Allow", - Action = [ + Effect = "Allow", + Action = [ "pipes:CreatePipe", "pipes:StartPipe", "pipes:StopPipe", @@ -253,8 +265,8 @@ resource "aws_iam_policy" "fifo_pipe_policy" { ] }, { - Effect = "Allow", - Action = [ + Effect = "Allow", + Action = [ "iam:PassRole" ], Resource = aws_iam_role.ecs_task_exec_role.arn @@ -263,46 +275,46 @@ resource "aws_iam_policy" "fifo_pipe_policy" { }) } - resource "aws_iam_role_policy_attachment" "fifo_pipe_policy_attachment" { - role = aws_iam_role.fifo_pipe_role.name - policy_arn = aws_iam_policy.fifo_pipe_policy.arn - } +resource "aws_iam_role_policy_attachment" "fifo_pipe_policy_attachment" { + role = aws_iam_role.fifo_pipe_role.name + policy_arn = aws_iam_policy.fifo_pipe_policy.arn +} + - # EventBridge Pipe resource "aws_pipes_pipe" "fifo_pipe" { - name = "${local.short_prefix}-pipe" - role_arn = aws_iam_role.fifo_pipe_role.arn - source = aws_sqs_queue.supplier_fifo_queue.arn - target = aws_ecs_cluster.ecs_cluster.arn - + name = "${local.short_prefix}-pipe" + role_arn = aws_iam_role.fifo_pipe_role.arn + source = aws_sqs_queue.supplier_fifo_queue.arn + target = aws_ecs_cluster.ecs_cluster.arn + target_parameters { ecs_task_parameters { task_definition_arn = aws_ecs_task_definition.ecs_task.arn launch_type = "FARGATE" 
network_configuration { aws_vpc_configuration { - subnets = data.aws_subnets.default.ids + subnets = data.aws_subnets.default.ids assign_public_ip = "ENABLED" } } overrides { container_override { - cpu = 2048 + cpu = 2048 name = "${local.short_prefix}-process-records-container" environment { name = "EVENT_DETAILS" value = "$.body" } - memory = 8192 + memory = 8192 memory_reservation = 1024 } } task_count = 1 } - - } - log_configuration { + + } + log_configuration { include_execution_data = ["ALL"] level = "ERROR" cloudwatch_logs_log_destination {