From a60d105c103516a169d190ee066b9fd982e70853 Mon Sep 17 00:00:00 2001
From: Vincent <97131062+vincbeck@users.noreply.github.com>
Date: Tue, 19 Nov 2024 03:14:06 -0500
Subject: [PATCH] Use `S3CopyObjectOperator` in `example_comprehend_document_classifier` (#44160)

---
 .../airflow/providers/amazon/aws/hooks/s3.py   |  5 +++++
 .../providers/amazon/aws/operators/s3.py       |  5 +++++
 .../example_comprehend_document_classifier.py  | 19 +++++++++++++++++--
 3 files changed, 27 insertions(+), 2 deletions(-)

diff --git a/providers/src/airflow/providers/amazon/aws/hooks/s3.py b/providers/src/airflow/providers/amazon/aws/hooks/s3.py
index 18405ca17d7b1..5f9fe62d3e9d6 100644
--- a/providers/src/airflow/providers/amazon/aws/hooks/s3.py
+++ b/providers/src/airflow/providers/amazon/aws/hooks/s3.py
@@ -1297,6 +1297,7 @@ def copy_object(
         dest_bucket_name: str | None = None,
         source_version_id: str | None = None,
         acl_policy: str | None = None,
+        meta_data_directive: str | None = None,
         **kwargs,
     ) -> None:
         """
@@ -1326,10 +1327,14 @@ def copy_object(
         :param source_version_id: Version ID of the source object (OPTIONAL)
         :param acl_policy: The string to specify the canned ACL policy for the
             object to be copied which is private by default.
+        :param meta_data_directive: Whether to `COPY` the metadata from the source object or `REPLACE` it
+            with metadata that's provided in the request.
         """
         acl_policy = acl_policy or "private"
         if acl_policy != NO_ACL:
             kwargs["ACL"] = acl_policy
+        if meta_data_directive:
+            kwargs["MetadataDirective"] = meta_data_directive

         dest_bucket_name, dest_bucket_key = self.get_s3_bucket_key(
             dest_bucket_name, dest_bucket_key, "dest_bucket_name", "dest_bucket_key"
diff --git a/providers/src/airflow/providers/amazon/aws/operators/s3.py b/providers/src/airflow/providers/amazon/aws/operators/s3.py
index 998c7a81065dc..6ab2ba6750892 100644
--- a/providers/src/airflow/providers/amazon/aws/operators/s3.py
+++ b/providers/src/airflow/providers/amazon/aws/operators/s3.py
@@ -282,6 +282,8 @@ class S3CopyObjectOperator(BaseOperator):
                  CA cert bundle than the one used by botocore.
     :param acl_policy: String specifying the canned ACL policy for the file being
         uploaded to the S3 bucket.
+    :param meta_data_directive: Whether to `COPY` the metadata from the source object or `REPLACE` it with
+        metadata that's provided in the request.
     """

     template_fields: Sequence[str] = (
@@ -302,6 +304,7 @@ def __init__(
         aws_conn_id: str | None = "aws_default",
         verify: str | bool | None = None,
         acl_policy: str | None = None,
+        meta_data_directive: str | None = None,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -314,6 +317,7 @@ def __init__(
         self.aws_conn_id = aws_conn_id
         self.verify = verify
         self.acl_policy = acl_policy
+        self.meta_data_directive = meta_data_directive

     def execute(self, context: Context):
         s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
@@ -324,6 +328,7 @@ def execute(self, context: Context):
             self.dest_bucket_name,
             self.source_version_id,
             self.acl_policy,
+            self.meta_data_directive,
         )

     def get_openlineage_facets_on_start(self):
diff --git a/providers/tests/system/amazon/aws/example_comprehend_document_classifier.py b/providers/tests/system/amazon/aws/example_comprehend_document_classifier.py
index 160e3f3cadf84..4a103a9265372 100644
--- a/providers/tests/system/amazon/aws/example_comprehend_document_classifier.py
+++ b/providers/tests/system/amazon/aws/example_comprehend_document_classifier.py
@@ -28,6 +28,7 @@
     ComprehendCreateDocumentClassifierOperator,
 )
 from airflow.providers.amazon.aws.operators.s3 import (
+    S3CopyObjectOperator,
     S3CreateBucketOperator,
     S3CreateObjectOperator,
     S3DeleteBucketOperator,
@@ -140,7 +141,14 @@ def copy_data_to_s3(bucket: str, sources: list[dict], prefix: str, number_of_cop
     http_to_s3_configs = [
         {
             "endpoint": source["endpoint"],
-            "s3_key": f"{prefix}/{os.path.splitext(os.path.basename(source['fileName']))[0]}-{counter}{os.path.splitext(os.path.basename(source['fileName']))[1]}",
+            "s3_key": f"{prefix}/{os.path.splitext(os.path.basename(source['fileName']))[0]}-0{os.path.splitext(os.path.basename(source['fileName']))[1]}",
+        }
+        for source in sources
+    ]
+    copy_to_s3_configs = [
+        {
+            "source_bucket_key": f"{prefix}/{os.path.splitext(os.path.basename(source['fileName']))[0]}-0{os.path.splitext(os.path.basename(source['fileName']))[1]}",
+            "dest_bucket_key": f"{prefix}/{os.path.splitext(os.path.basename(source['fileName']))[0]}-{counter}{os.path.splitext(os.path.basename(source['fileName']))[1]}",
         }
         for counter in range(number_of_copies)
         for source in sources
@@ -170,7 +178,14 @@ def delete_connection(conn_id):
         s3_bucket=bucket,
     ).expand_kwargs(http_to_s3_configs)

-    chain(create_connection(http_conn_id), http_to_s3_task, delete_connection(http_conn_id))
+    s3_copy_task = S3CopyObjectOperator.partial(
+        task_id="s3_copy_task",
+        source_bucket_name=bucket,
+        dest_bucket_name=bucket,
+        meta_data_directive="REPLACE",
+    ).expand_kwargs(copy_to_s3_configs)
+
+    chain(create_connection(http_conn_id), http_to_s3_task, s3_copy_task, delete_connection(http_conn_id))

 with DAG(