diff --git a/airflow/providers/google/cloud/hooks/gcs.py b/airflow/providers/google/cloud/hooks/gcs.py index 885fd60ae007a..29ad6ac4386e7 100644 --- a/airflow/providers/google/cloud/hooks/gcs.py +++ b/airflow/providers/google/cloud/hooks/gcs.py @@ -447,6 +447,7 @@ def upload( chunk_size: Optional[int] = None, timeout: Optional[int] = DEFAULT_TIMEOUT, num_max_attempts: int = 1, + metadata: Optional[dict] = None, ) -> None: """ Uploads a local file or file data as string or bytes to Google Cloud Storage. @@ -461,6 +462,7 @@ def upload( :param chunk_size: Blob chunk size. :param timeout: Request timeout in seconds. :param num_max_attempts: Number of attempts to try to upload the file. + :param metadata: The metadata to be uploaded with the file. """ def _call_with_retry(f: Callable[[], None]) -> None: @@ -493,6 +495,10 @@ def _call_with_retry(f: Callable[[], None]) -> None: client = self.get_conn() bucket = client.bucket(bucket_name) blob = bucket.blob(blob_name=object_name, chunk_size=chunk_size) + + if metadata: + blob.metadata = metadata + if filename and data: raise ValueError( "'filename' and 'data' parameter provided. Please " diff --git a/tests/providers/google/cloud/hooks/test_gcs.py b/tests/providers/google/cloud/hooks/test_gcs.py index 8e12366f916a2..35df6247ef927 100644 --- a/tests/providers/google/cloud/hooks/test_gcs.py +++ b/tests/providers/google/cloud/hooks/test_gcs.py @@ -789,15 +789,21 @@ def tearDown(self): def test_upload_file(self, mock_service): test_bucket = 'test_bucket' test_object = 'test_object' + metadata = {'key1': 'val1', 'key2': 'key2'} - upload_method = mock_service.return_value.bucket.return_value.blob.return_value.upload_from_filename + bucket_mock = mock_service.return_value.bucket + blob_object = bucket_mock.return_value.blob - self.gcs_hook.upload(test_bucket, test_object, filename=self.testfile.name) + upload_method = blob_object.return_value.upload_from_filename + + self.gcs_hook.upload(test_bucket, test_object, filename=self.testfile.name, metadata=metadata) upload_method.assert_called_once_with( filename=self.testfile.name, content_type='application/octet-stream', timeout=60 ) + self.assertEqual(metadata, blob_object.return_value.metadata) + @mock.patch(GCS_STRING.format('GCSHook.get_conn')) def test_upload_file_gzip(self, mock_service): test_bucket = 'test_bucket'