diff --git a/providers/tests/system/amazon/aws/example_comprehend_document_classifier.py b/providers/tests/system/amazon/aws/example_comprehend_document_classifier.py
index 4a103a9265372..b0bf41209785c 100644
--- a/providers/tests/system/amazon/aws/example_comprehend_document_classifier.py
+++ b/providers/tests/system/amazon/aws/example_comprehend_document_classifier.py
@@ -16,12 +16,10 @@
 # under the License.
 from __future__ import annotations
 
-import os
 from datetime import datetime
 
-from airflow import DAG, settings
+from airflow import DAG
 from airflow.decorators import task, task_group
-from airflow.models import Connection
 from airflow.models.baseoperator import chain
 from airflow.providers.amazon.aws.hooks.comprehend import ComprehendHook
 from airflow.providers.amazon.aws.operators.comprehend import (
@@ -36,31 +34,27 @@ from airflow.providers.amazon.aws.sensors.comprehend import (
     ComprehendCreateDocumentClassifierCompletedSensor,
 )
-from airflow.providers.amazon.aws.transfers.http_to_s3 import HttpToS3Operator
 from airflow.utils.trigger_rule import TriggerRule
 
 from providers.tests.system.amazon.aws.utils import SystemTestContextBuilder
 
 ROLE_ARN_KEY = "ROLE_ARN"
-sys_test_context_task = SystemTestContextBuilder().add_variable(ROLE_ARN_KEY).build()
+BUCKET_NAME_KEY = "BUCKET_NAME"
+BUCKET_KEY_DISCHARGE_KEY = "BUCKET_KEY_DISCHARGE"
+BUCKET_KEY_DOCTORS_NOTES = "BUCKET_KEY_DOCTORS_NOTES"
+sys_test_context_task = (
+    SystemTestContextBuilder()
+    .add_variable(ROLE_ARN_KEY)
+    .add_variable(BUCKET_NAME_KEY)
+    .add_variable(BUCKET_KEY_DISCHARGE_KEY)
+    .add_variable(BUCKET_KEY_DOCTORS_NOTES)
+    .build()
+)
 
 DAG_ID = "example_comprehend_document_classifier"
 ANNOTATION_BUCKET_KEY = "training-labels/label.csv"
 TRAINING_DATA_PREFIX = "training-docs"
 
-# To create a custom document classifier, we need a minimum of 10 documents for each label.
-# for testing purpose, we will generate 10 copies of each document referenced below.
-PUBLIC_DATA_SOURCES = [
-    {
-        "fileName": "discharge-summary.pdf",
-        "endpoint": "aws-samples/amazon-comprehend-examples/blob/master/building-custom-classifier/sample-docs/discharge-summary.pdf?raw=true",
-    },
-    {
-        "fileName": "doctors-notes.pdf",
-        "endpoint": "aws-samples/amazon-comprehend-examples/blob/master/building-custom-classifier/sample-docs/doctors-notes.pdf?raw=true",
-    },
-]
-
 # Annotations file won't allow headers
 # label,document name,page number
@@ -119,74 +113,27 @@ def delete_classifier(document_classifier_arn: str):
     )
 
 
-@task_group
-def copy_data_to_s3(bucket: str, sources: list[dict], prefix: str, number_of_copies=1):
-    """
-
-    Copy some sample data to S3 using HttpToS3Operator.
-
-    :param bucket: Name of the Amazon S3 bucket to send the data.
-    :param prefix: Folder to store the files
-    :param number_of_copies: Number of files to create for a document from the sources
-    :param sources: Public available data locations
-    """
-
-    """
-    EX: If number_of_copies is 2, sources has file name 'file.pdf', and prefix is 'training-docs'.
-    Will generate two copies and upload to s3:
-        - training-docs/file-0.pdf
-        - training-docs/file-1.pdf
-    """
-
-    http_to_s3_configs = [
+@task
+def create_kwargs_discharge():
+    return [
         {
-            "endpoint": source["endpoint"],
-            "s3_key": f"{prefix}/{os.path.splitext(os.path.basename(source['fileName']))[0]}-0{os.path.splitext(os.path.basename(source['fileName']))[1]}",
+            "source_bucket_key": str(test_context[BUCKET_KEY_DISCHARGE_KEY]),
+            "dest_bucket_key": f"{TRAINING_DATA_PREFIX}/discharge-summary-{counter}.pdf",
         }
-        for source in sources
+        for counter in range(10)
     ]
-    copy_to_s3_configs = [
+
+
+@task
+def create_kwargs_doctors_notes():
+    return [
         {
-            "source_bucket_key": f"{prefix}/{os.path.splitext(os.path.basename(source['fileName']))[0]}-0{os.path.splitext(os.path.basename(source['fileName']))[1]}",
-            "dest_bucket_key": f"{prefix}/{os.path.splitext(os.path.basename(source['fileName']))[0]}-{counter}{os.path.splitext(os.path.basename(source['fileName']))[1]}",
+            "source_bucket_key": str(test_context[BUCKET_KEY_DOCTORS_NOTES]),
+            "dest_bucket_key": f"{TRAINING_DATA_PREFIX}/doctors-notes-{counter}.pdf",
         }
-        for counter in range(number_of_copies)
-        for source in sources
+        for counter in range(10)
     ]
 
-    @task
-    def create_connection(conn_id):
-        conn = Connection(
-            conn_id=conn_id,
-            conn_type="http",
-            host="https://github.com/",
-        )
-        session = settings.Session()
-        session.add(conn)
-        session.commit()
-
-    @task(trigger_rule=TriggerRule.ALL_DONE)
-    def delete_connection(conn_id):
-        session = settings.Session()
-        conn_to_details = session.query(Connection).filter(Connection.conn_id == conn_id).first()
-        session.delete(conn_to_details)
-        session.commit()
-
-    http_to_s3_task = HttpToS3Operator.partial(
-        task_id="http_to_s3_task",
-        http_conn_id=http_conn_id,
-        s3_bucket=bucket,
-    ).expand_kwargs(http_to_s3_configs)
-
-    s3_copy_task = S3CopyObjectOperator.partial(
-        task_id="s3_copy_task",
-        source_bucket_name=bucket,
-        dest_bucket_name=bucket,
-        meta_data_directive="REPLACE",
-    ).expand_kwargs(copy_to_s3_configs)
-
-    chain(create_connection(http_conn_id), http_to_s3_task, s3_copy_task, delete_connection(http_conn_id))
-
 
 with DAG(
     dag_id=DAG_ID,
@@ -199,7 +146,6 @@ def delete_connection(conn_id):
     env_id = test_context["ENV_ID"]
     classifier_name = f"{env_id}-custom-document-classifier"
     bucket_name = f"{env_id}-comprehend-document-classifier"
-    http_conn_id = f"{env_id}-git"
 
     input_data_configurations = {
         "S3Uri": f"s3://{bucket_name}/{ANNOTATION_BUCKET_KEY}",
@@ -219,6 +165,22 @@ def delete_connection(conn_id):
         bucket_name=bucket_name,
     )
 
+    discharge_kwargs = create_kwargs_discharge()
+    s3_copy_discharge_task = S3CopyObjectOperator.partial(
+        task_id="s3_copy_discharge_task",
+        source_bucket_name=test_context[BUCKET_NAME_KEY],
+        dest_bucket_name=bucket_name,
+        meta_data_directive="REPLACE",
+    ).expand_kwargs(discharge_kwargs)
+
+    doctors_notes_kwargs = create_kwargs_doctors_notes()
+    s3_copy_doctors_notes_task = S3CopyObjectOperator.partial(
+        task_id="s3_copy_doctors_notes_task",
+        source_bucket_name=test_context[BUCKET_NAME_KEY],
+        dest_bucket_name=bucket_name,
+        meta_data_directive="REPLACE",
+    ).expand_kwargs(doctors_notes_kwargs)
+
     upload_annotation_file = S3CreateObjectOperator(
         task_id="upload_annotation_file",
         s3_bucket=bucket_name,
@@ -236,10 +198,9 @@ def delete_connection(conn_id):
     chain(
         test_context,
         create_bucket,
+        s3_copy_discharge_task,
+        s3_copy_doctors_notes_task,
         upload_annotation_file,
-        copy_data_to_s3(
-            bucket=bucket_name, sources=PUBLIC_DATA_SOURCES, prefix=TRAINING_DATA_PREFIX, number_of_copies=10
-        ),
         # TEST BODY
         document_classifier_workflow(),
         # TEST TEARDOWN
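
The new create_kwargs_* tasks in this patch rely on Airflow dynamic task mapping: S3CopyObjectOperator.partial(...) pins the arguments shared by every copy, and .expand_kwargs(...) fans out one mapped task instance per dictionary the upstream task returns at runtime. A minimal standalone sketch of that pattern follows; the bucket names, DAG id, and task ids are hypothetical placeholders, not values from the patch.

from __future__ import annotations

from datetime import datetime

from airflow import DAG
from airflow.decorators import task
from airflow.providers.amazon.aws.operators.s3 import S3CopyObjectOperator

# Hypothetical names, for illustration only.
SOURCE_BUCKET = "my-source-bucket"
DEST_BUCKET = "my-dest-bucket"
PREFIX = "training-docs"


@task
def build_copy_kwargs() -> list[dict]:
    # One dict per mapped task instance; keys must be valid S3CopyObjectOperator arguments.
    return [
        {"source_bucket_key": "sample.pdf", "dest_bucket_key": f"{PREFIX}/sample-{i}.pdf"}
        for i in range(10)
    ]


with DAG(dag_id="s3_copy_fanout_sketch", start_date=datetime(2024, 1, 1), schedule=None, catchup=False):
    # .partial() fixes the arguments common to all copies; .expand_kwargs() creates one
    # mapped S3CopyObjectOperator instance per dict emitted by build_copy_kwargs().
    S3CopyObjectOperator.partial(
        task_id="s3_copy_fanout",
        source_bucket_name=SOURCE_BUCKET,
        dest_bucket_name=DEST_BUCKET,
        meta_data_directive="REPLACE",
    ).expand_kwargs(build_copy_kwargs())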