diff --git a/providers/src/airflow/providers/google/cloud/openlineage/utils.py b/providers/src/airflow/providers/google/cloud/openlineage/utils.py
index ff6c4c05bb6a4..6b8c93063fc39 100644
--- a/providers/src/airflow/providers/google/cloud/openlineage/utils.py
+++ b/providers/src/airflow/providers/google/cloud/openlineage/utils.py
@@ -54,6 +54,24 @@ def extract_ds_name_from_gcs_path(path: str) -> str:
 
     Returns:
         The processed dataset name.
+
+    Examples:
+        >>> extract_ds_name_from_gcs_path("/dir/file.*")
+        'dir'
+        >>> extract_ds_name_from_gcs_path("/dir/pre_")
+        'dir'
+        >>> extract_ds_name_from_gcs_path("/dir/file.txt")
+        'dir/file.txt'
+        >>> extract_ds_name_from_gcs_path("/dir/file.")
+        'dir'
+        >>> extract_ds_name_from_gcs_path("/dir/")
+        'dir'
+        >>> extract_ds_name_from_gcs_path("")
+        '/'
+        >>> extract_ds_name_from_gcs_path("/")
+        '/'
+        >>> extract_ds_name_from_gcs_path(".")
+        '/'
     """
     if WILDCARD in path:
         path = path.split(WILDCARD, maxsplit=1)[0]
diff --git a/providers/src/airflow/providers/google/cloud/transfers/s3_to_gcs.py b/providers/src/airflow/providers/google/cloud/transfers/s3_to_gcs.py
index cbd82724ede67..18687f1eb93d1 100644
--- a/providers/src/airflow/providers/google/cloud/transfers/s3_to_gcs.py
+++ b/providers/src/airflow/providers/google/cloud/transfers/s3_to_gcs.py
@@ -345,3 +345,26 @@ def get_transfer_hook(self):
             gcp_conn_id=self.gcp_conn_id,
             impersonation_chain=self.google_impersonation_chain,
         )
+
+    def get_openlineage_facets_on_start(self):
+        """
+        Build the OpenLineage input and output datasets for this transfer.
+
+        The input dataset is named after the S3 prefix with surrounding slashes stripped,
+        or ``"/"`` when nothing is left. The output dataset name is computed from the GCS
+        blob parsed out of ``dest_gcs``; the S3 prefix is appended to that blob first when
+        ``apply_gcs_prefix`` is false.
+        """
+        from airflow.providers.common.compat.openlineage.facet import Dataset
+        from airflow.providers.google.cloud.hooks.gcs import _parse_gcs_url
+        from airflow.providers.google.cloud.openlineage.utils import extract_ds_name_from_gcs_path
+        from airflow.providers.openlineage.extractors import OperatorLineage
+
+        gcs_bucket, gcs_blob = _parse_gcs_url(self.dest_gcs)
+        if not self.apply_gcs_prefix:
+            gcs_blob += self.prefix
+
+        return OperatorLineage(
+            inputs=[Dataset(namespace=f"s3://{self.bucket}", name=self.prefix.strip("/") or "/")],
+            outputs=[Dataset(namespace=f"gs://{gcs_bucket}", name=extract_ds_name_from_gcs_path(gcs_blob))],
+        )
diff --git a/providers/tests/google/cloud/transfers/test_local_to_gcs.py b/providers/tests/google/cloud/transfers/test_local_to_gcs.py
index 0ebf2f595032c..bf0ad43e1108d 100644
--- a/providers/tests/google/cloud/transfers/test_local_to_gcs.py
+++ b/providers/tests/google/cloud/transfers/test_local_to_gcs.py
@@ -195,7 +195,9 @@ def test_get_openlineage_facets_on_start_with_string_src(
         assert not result.run_facets
         assert len(result.outputs) == 1
         assert len(result.inputs) == 1
+        assert result.outputs[0].namespace == "gs://dummy"
         assert result.outputs[0].name == expected_output
+        assert result.inputs[0].namespace == "file"
         assert result.inputs[0].name == expected_input
         if symlink:
             assert result.inputs[0].facets["symlink"] == SymlinksDatasetFacet(
diff --git a/providers/tests/google/cloud/transfers/test_s3_to_gcs.py b/providers/tests/google/cloud/transfers/test_s3_to_gcs.py
index 896d281d32891..9539e257aff71 100644
--- a/providers/tests/google/cloud/transfers/test_s3_to_gcs.py
+++ b/providers/tests/google/cloud/transfers/test_s3_to_gcs.py
@@ -270,6 +270,39 @@ def test_execute_apply_gcs_prefix(
 
         assert sorted([s3_prefix + s3_object]) == sorted(uploaded_files)
 
+    @pytest.mark.parametrize(
+        ("s3_prefix", "gcs_destination", "apply_gcs_prefix", "expected_input", "expected_output"),
+        [
+            ("dir/pre", "gs://bucket/dest_dir/", False, "dir/pre", "dest_dir/dir"),
+            ("dir/pre/", "gs://bucket/dest_dir/", False, "dir/pre", "dest_dir/dir/pre"),
+            ("dir/pre", "gs://bucket/dest_dir/", True, "dir/pre", "dest_dir"),
+            ("dir/pre", "gs://bucket/", False, "dir/pre", "dir"),
+            ("dir/pre", "gs://bucket/", True, "dir/pre", "/"),
+            ("", "gs://bucket/", False, "/", "/"),
+            ("", "gs://bucket/", True, "/", "/"),
+        ],
+    )
+    def test_get_openlineage_facets_on_start(
+        self, s3_prefix, gcs_destination, apply_gcs_prefix, expected_input, expected_output
+    ):
+        """Check lineage dataset namespaces and names for each prefix/destination combination."""
+        operator = S3ToGCSOperator(
+            task_id=TASK_ID,
+            bucket=S3_BUCKET,
+            prefix=s3_prefix,
+            dest_gcs=gcs_destination,
+            apply_gcs_prefix=apply_gcs_prefix,
+        )
+        result = operator.get_openlineage_facets_on_start()
+        assert not result.job_facets
+        assert not result.run_facets
+        assert len(result.outputs) == 1
+        assert len(result.inputs) == 1
+        assert result.outputs[0].namespace == "gs://bucket"
+        assert result.outputs[0].name == expected_output
+        assert result.inputs[0].namespace == f"s3://{S3_BUCKET}"
+        assert result.inputs[0].name == expected_input
+
 
 class TestS3ToGoogleCloudStorageOperatorDeferrable:
     @mock.patch("airflow.providers.google.cloud.transfers.s3_to_gcs.CloudDataTransferServiceHook")