diff --git a/pyproject.toml b/pyproject.toml index 387044c01..c3fd51c4f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -83,7 +83,7 @@ tests = [ "pytest-sugar>=0.9.6", "pytest-cov>=4.1.0", "pytest-mock>=3.12.0", - "pytest-servers[all]>=0.5.8", + "pytest-servers[all]>=0.5.9", "pytest-benchmark[histogram]", "pytest-xdist>=3.3.1", "virtualenv", diff --git a/src/datachain/catalog/catalog.py b/src/datachain/catalog/catalog.py index a3c16503b..0c678cba6 100644 --- a/src/datachain/catalog/catalog.py +++ b/src/datachain/catalog/catalog.py @@ -1236,10 +1236,16 @@ def ls_dataset_rows( return q.to_db_records() - def signed_url(self, source: str, path: str, client_config=None) -> str: + def signed_url( + self, + source: str, + path: str, + version_id: Optional[str] = None, + client_config=None, + ) -> str: client_config = client_config or self.client_config client = Client.get_client(source, self.cache, **client_config) - return client.url(path) + return client.url(path, version_id=version_id) def export_dataset_table( self, diff --git a/src/datachain/client/azure.py b/src/datachain/client/azure.py index 4421945c6..ed00fda75 100644 --- a/src/datachain/client/azure.py +++ b/src/datachain/client/azure.py @@ -1,4 +1,5 @@ -from typing import Any +from typing import Any, Optional +from urllib.parse import parse_qs, urlsplit, urlunsplit from adlfs import AzureBlobFileSystem from tqdm import tqdm @@ -25,6 +26,16 @@ def info_to_file(self, v: dict[str, Any], path: str) -> File: size=v.get("size", ""), ) + def url(self, path: str, expires: int = 3600, **kwargs) -> str: + """ + Generate a signed URL for the given path. + """ + version_id = kwargs.pop("version_id", None) + result = self.fs.sign( + self.get_full_path(path, version_id), expiration=expires, **kwargs + ) + return result + (f"&versionid={version_id}" if version_id else "") + async def _fetch_flat(self, start_prefix: str, result_queue: ResultQueue) -> None: prefix = start_prefix if prefix: @@ -57,4 +68,13 @@ async def _fetch_flat(self, start_prefix: str, result_queue: ResultQueue) -> Non finally: result_queue.put_nowait(None) + @classmethod + def version_path(cls, path: str, version_id: Optional[str]) -> str: + parts = list(urlsplit(path)) + query = parse_qs(parts[3]) + if "versionid" in query: + raise ValueError("path already includes a version query") + parts[3] = f"versionid={version_id}" if version_id else "" + return urlunsplit(parts) + _fetch_default = _fetch_flat diff --git a/src/datachain/client/fsspec.py b/src/datachain/client/fsspec.py index 43d16a374..c1baa6afe 100644 --- a/src/datachain/client/fsspec.py +++ b/src/datachain/client/fsspec.py @@ -202,7 +202,11 @@ def fs(self) -> "AbstractFileSystem": return self._fs def url(self, path: str, expires: int = 3600, **kwargs) -> str: - return self.fs.sign(self.get_full_path(path), expiration=expires, **kwargs) + return self.fs.sign( + self.get_full_path(path, kwargs.pop("version_id", None)), + expiration=expires, + **kwargs, + ) async def get_current_etag(self, file: "File") -> str: kwargs = {} diff --git a/src/datachain/client/gcs.py b/src/datachain/client/gcs.py index 2e3981b8b..f7f9907d5 100644 --- a/src/datachain/client/gcs.py +++ b/src/datachain/client/gcs.py @@ -4,7 +4,6 @@ from collections.abc import Iterable from datetime import datetime from typing import Any, Optional, cast -from urllib.parse import urlsplit from dateutil.parser import isoparse from gcsfs import GCSFileSystem @@ -39,9 +38,13 @@ def url(self, path: str, expires: int = 3600, **kwargs) -> str: If the client is anonymous, 
a public URL is returned instead (see https://cloud.google.com/storage/docs/access-public-data#api-link). """ + version_id = kwargs.pop("version_id", None) if self.fs.storage_options.get("token") == "anon": - return f"https://storage.googleapis.com/{self.name}/{path}" - return self.fs.sign(self.get_full_path(path), expiration=expires, **kwargs) + query = f"?generation={version_id}" if version_id else "" + return f"https://storage.googleapis.com/{self.name}/{path}{query}" + return self.fs.sign( + self.get_full_path(path, version_id), expiration=expires, **kwargs + ) @staticmethod def parse_timestamp(timestamp: str) -> datetime: @@ -133,25 +136,6 @@ def info_to_file(self, v: dict[str, Any], path: str) -> File: size=v.get("size", ""), ) - @classmethod - def _split_version(cls, path: str) -> tuple[str, Optional[str]]: - parts = list(urlsplit(path)) - scheme = parts[0] - parts = GCSFileSystem._split_path( # pylint: disable=protected-access - path, version_aware=True - ) - bucket, key, generation = parts - scheme = f"{scheme}://" if scheme else "" - return f"{scheme}{bucket}/{key}", generation - - @classmethod - def _join_version(cls, path: str, version_id: Optional[str]) -> str: - path, path_version = cls._split_version(path) - if path_version: - raise ValueError("path already includes an object generation") - return f"{path}#{version_id}" if version_id else path - @classmethod def version_path(cls, path: str, version_id: Optional[str]) -> str: - path, _ = cls._split_version(path) - return cls._join_version(path, version_id) + return f"{path}#{version_id}" if version_id else path diff --git a/src/datachain/client/s3.py b/src/datachain/client/s3.py index 37de24442..ac18e81b1 100644 --- a/src/datachain/client/s3.py +++ b/src/datachain/client/s3.py @@ -1,5 +1,6 @@ import asyncio from typing import Any, Optional, cast +from urllib.parse import parse_qs, urlsplit, urlunsplit from botocore.exceptions import NoCredentialsError from s3fs import S3FileSystem @@ -121,6 +122,15 @@ def _entry_from_boto(self, v, bucket, versions=False) -> File: size=v["Size"], ) + @classmethod + def version_path(cls, path: str, version_id: Optional[str]) -> str: + parts = list(urlsplit(path)) + query = parse_qs(parts[3]) + if "versionId" in query: + raise ValueError("path already includes a version query") + parts[3] = f"versionId={version_id}" if version_id else "" + return urlunsplit(parts) + async def _fetch_dir( self, prefix, diff --git a/tests/func/fake-service-account-credentials.json b/tests/func/fake-service-account-credentials.json new file mode 100644 index 000000000..b830d94b6 --- /dev/null +++ b/tests/func/fake-service-account-credentials.json @@ -0,0 +1,9 @@ +{ + "type": "service_account", + "project_id": "gcsfs", + "private_key_id": "84e3fd6d7101ec632e7348e8940b2aca71133e71", + "private_key": "-----BEGIN PRIVATE 
KEY-----\nMIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDAJWz1KlBu2jRE\nlUahHKuJes34hj4pr8ADhgejpAguBBrubXVvSro7aSSbvyDC/GIcyDQ8Q33YK/kT\nufQvCez7iIACbtP53o6WjcrIAP+l8z9RUL9so+sBCaVRZzh74+cEMfWIbc3ACBB5\nU2BPBWQFtr3Qtbe8TUJ+liNcLb8I2JznfydHvl9cn0/50HeOB99Xho5JAY75aE0Y\nT+/aMTFlr/kUbekLRRi4pyE+uOA/ei5RmfwzqO366YLMtEC2DaHwTqSuxBWnbtTW\nu/OvYpmPHazd6own2zJLQ0Elnm5WC/d9YmxhHi/8pJFkkbVf/2CYWEBbmBI3ZOx3\n/nHQwcIPAgMBAAECggEAUztC/dYE/me10WmKLTrykTxpYTihT8RqG/ygbYGd63Tq\nx5IRlxJbJmYOrgp2IhBaXZZZjis8JXoyzBk2TXPyvChuLt+cIfYGdO/ZwZYxJ0z9\nhfdA3EoK/6mSe3cHcB8SEG6lqaHKyN6VaEC2DLTMlW8JvREiFEaxQY0+puzH/ge4\n2EypCP4pvlveH78EIIipPgWcJYGpv0bv8KErECuVHRjJv6vZqUjQdcIi73mCz/5u\nnQqLY8j9lOuCr9vBis7DZIyY2tn4vfqcqxfH9wuIFXnzIQW6Wyg0+bBQydHg1kJ2\nFOszfkBVxZ6LpcHGB4CV4c5z7Me2cMReXQz6VsyoLQKBgQD9v92rHZYDBy4/vGxx\nbpfUkAlcCGW8GXu+qsdmyhZdjSdjDLY6lav+6UoHIJgmnA7LsKPFgnEDrdn78KBb\n3wno3VHfozL5kF887q9hC/+UurwScCKIw5QkmWtsStVgjr6wPmAu6rspMz5xNjaa\nSU4YzlNcbBUUXUawhXytWPR+OwKBgQDB2bDCD00R2yfYFdjAKapqenOtMvrnihUi\nW9Se7Yizme7s25fDxF5CBPpOdKPU2EZUlqBC/5182oMUP/xYUOHJkuUhbYcvU0qr\n+BQewLwr6rs+O1QPTh/6e70SUFR+YJLaAHkDc6fvcdjtl+Zx/p02Zj+UiW3/D4Jj\nc0EqVr4qPQKBgQCbJx3a6xQ2dcWJoySLlxuvFQMkCt5pzQsk4jdaWmaifRSAM92Y\npLut+ecRxJRDx1gko7T/p2qC3WJT8iWbBx2ADRNqstcQUX5qO2dw5202+5bTj00O\nYsfKOSS96mPdzmo6SWl2RoB6CKM9hfCNFhVyhXXjJRMeiIoYlQZO1/1m0QKBgCzz\nat6FJ8z1MdcUsc9VmhPY00wdXzsjtOTjwHkeAa4MCvBXt2iI94Z9mwFoYLkxcZWZ\n3A3NMlrKXMzsTXq5PrI8Yu+Oc2OQ/+bCvv+ml7vjUYoLveFSr22pFd3STNWFVWhB\n5c3cGtwWXUQzDhfu/8umiCXMfHpBwW2IQ1srBCvNAoGATcC3oCFBC/HdGxdeJC5C\n59EoFvKdZsAdc2I5GS/DtZ1Wo9sXqubCaiUDz+4yty+ssHIZ1ikFr8rWfL6KFEs2\niTe+kgM/9FLFtftf1WDpbfIOumbz/6CiGLqsGNlO3ZaU0kYJ041SZ8RleTOYa0zO\noSTLwBo3vje+aflytEwS8SI=\n-----END PRIVATE KEY-----", + "client_email": "fake@gscfs.iam.gserviceaccount.com", + "auth_uri": "https://accounts.google.com/o/oauth2/auth", + "token_uri": "https://oauth2.googleapis.com/token" + } diff --git a/tests/func/test_catalog.py b/tests/func/test_catalog.py index b8e7e87c5..58c37c7de 100644 --- a/tests/func/test_catalog.py +++ b/tests/func/test_catalog.py @@ -3,6 +3,7 @@ from urllib.parse import urlparse import pytest +import requests import yaml from fsspec.implementations.local import LocalFileSystem @@ -993,3 +994,61 @@ def test_garbage_collect(cloud_test_catalog, from_cli, capsys): else: catalog.cleanup_tables(temp_tables) assert catalog.get_temp_table_names() == [] + + +@pytest.fixture +def gcs_fake_credentials(monkeypatch): + # For signed URL tests to work we need to setup some fake credentials + # that looks like real ones + monkeypatch.setenv( + "GOOGLE_APPLICATION_CREDENTIALS", + os.path.dirname(__file__) + "/fake-service-account-credentials.json", + ) + + +@pytest.mark.parametrize("tree", [{"test-signed-file": "original"}], indirect=True) +@pytest.mark.parametrize( + "cloud_type, version_aware", + (["s3", False], ["azure", False], ["gs", False]), + indirect=True, +) +def test_signed_url(cloud_test_catalog, gcs_fake_credentials): + signed_url = cloud_test_catalog.catalog.signed_url( + cloud_test_catalog.src_uri, "test-signed-file" + ) + content = requests.get(signed_url, timeout=10).text + assert content == "original" + + +@pytest.mark.parametrize( + "tree", [{"test-signed-file-versioned": "original"}], indirect=True +) +@pytest.mark.parametrize( + "cloud_type, version_aware", + (["s3", True], ["azure", True], ["gs", True]), + indirect=True, +) +def test_signed_url_versioned(cloud_test_catalog, gcs_fake_credentials): + file_name = "test-signed-file-versioned" + src_uri = cloud_test_catalog.src_uri + catalog = 
cloud_test_catalog.catalog + client = catalog.get_client(src_uri) + + original_version = client.get_file_info(file_name).version + + (cloud_test_catalog.src / file_name).write_text("modified") + + modified_version = client.get_file_info(file_name).version + + for version, expected in [ + (original_version, "original"), + (modified_version, "modified"), + ]: + signed_url = catalog.signed_url( + src_uri, + file_name, + version_id=version, + ) + + content = requests.get(signed_url, timeout=10).text + assert content == expected diff --git a/tests/unit/test_client_gcs.py b/tests/unit/test_client_gcs.py index 8c80db786..9365ff4c6 100644 --- a/tests/unit/test_client_gcs.py +++ b/tests/unit/test_client_gcs.py @@ -4,3 +4,11 @@ def test_anon_url(): client = Client.get_client("gs://foo", None, anon=True) assert client.url("bar") == "https://storage.googleapis.com/foo/bar" + + +def test_anon_versioned_url(): + client = Client.get_client("gs://foo", None, anon=True) + assert ( + client.url("bar", version_id="1234566") + == "https://storage.googleapis.com/foo/bar?generation=1234566" + ) diff --git a/tests/unit/test_client_s3.py b/tests/unit/test_client_s3.py index 3f72d3353..9b3d3f47a 100644 --- a/tests/unit/test_client_s3.py +++ b/tests/unit/test_client_s3.py @@ -1,5 +1,6 @@ import pytest +from datachain.client.s3 import ClientS3 from datachain.node import DirType, Node from datachain.nodes_thread_pool import NodeChunk @@ -77,3 +78,8 @@ def test_node_bucket_full_split(nodes): assert len(bkt[1]) == 1 assert len(bkt[2]) == 1 assert len(bkt[3]) == 1 + + +def test_version_path_already_has_version(): + with pytest.raises(ValueError): + ClientS3.version_path("s3://foo/bar?versionId=123", "456")
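
A minimal usage sketch of the version_id plumbing introduced above, not a definitive example from the PR: the anonymous-GCS call mirrors the assertion in tests/unit/test_client_gcs.py, while the Catalog.signed_url call is hypothetical. The `from datachain.catalog import get_catalog` and `from datachain.client import Client` import paths are assumptions, and the S3 bucket, key, and version ID are placeholders.

    import requests

    from datachain.catalog import get_catalog  # assumed entry point for a Catalog
    from datachain.client import Client        # assumed re-export of the fsspec Client

    # Anonymous GCS clients skip signing: url() turns version_id into a
    # "generation" query parameter, as asserted in tests/unit/test_client_gcs.py.
    gcs_client = Client.get_client("gs://foo", None, anon=True)
    assert (
        gcs_client.url("bar", version_id="1234566")
        == "https://storage.googleapis.com/foo/bar?generation=1234566"
    )

    # Catalog.signed_url() forwards version_id to Client.url(), which resolves it
    # into the backend-specific versioned form (a versionId= query for S3,
    # &versionid= appended to the signed URL for Azure, a #<generation> suffix
    # for GCS). Bucket, key, and version ID below are placeholders.
    catalog = get_catalog()
    signed_url = catalog.signed_url(
        "s3://my-bucket",
        "path/to/file.csv",
        version_id="3HL4kqtJlcpXroDTDm",
    )
    print(requests.get(signed_url, timeout=10).text)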