diff --git a/examples/get_started/udfs/parallel.py b/examples/get_started/udfs/parallel.py index 8b8f87d61..a907458a9 100644 --- a/examples/get_started/udfs/parallel.py +++ b/examples/get_started/udfs/parallel.py @@ -31,7 +31,7 @@ def path_len_benchmark(path): # Run in chain DataChain.from_storage( - path="gs://datachain-demo/dogs-and-cats/", + "gs://datachain-demo/dogs-and-cats/", ).settings(parallel=-1).map( path_len_benchmark, params=["file.path"], diff --git a/examples/get_started/udfs/simple.py b/examples/get_started/udfs/simple.py index a0aaea05b..a9af5818c 100644 --- a/examples/get_started/udfs/simple.py +++ b/examples/get_started/udfs/simple.py @@ -11,7 +11,7 @@ def path_len(path): if __name__ == "__main__": # Run in chain DataChain.from_storage( - path="gs://datachain-demo/dogs-and-cats/", + uri="gs://datachain-demo/dogs-and-cats/", ).map( path_len, params=["file.path"], diff --git a/src/datachain/catalog/catalog.py b/src/datachain/catalog/catalog.py index 7f583012c..073c7f647 100644 --- a/src/datachain/catalog/catalog.py +++ b/src/datachain/catalog/catalog.py @@ -1422,17 +1422,18 @@ def get_dataset_dependencies( return direct_dependencies - def ls_datasets(self) -> Iterator[DatasetRecord]: + def ls_datasets(self, include_listing: bool = False) -> Iterator[DatasetRecord]: datasets = self.metastore.list_datasets() for d in datasets: - if not d.is_bucket_listing: + if not d.is_bucket_listing or include_listing: yield d def list_datasets_versions( self, + include_listing: bool = False, ) -> Iterator[tuple[DatasetRecord, "DatasetVersion", Optional["Job"]]]: """Iterate over all dataset versions with related jobs.""" - datasets = list(self.ls_datasets()) + datasets = list(self.ls_datasets(include_listing=include_listing)) # preselect dataset versions jobs from db to avoid multiple queries jobs_ids: set[str] = { diff --git a/src/datachain/dataset.py b/src/datachain/dataset.py index 7ac9b8b1a..4b7aa454f 100644 --- a/src/datachain/dataset.py +++ b/src/datachain/dataset.py @@ -25,6 +25,7 @@ DATASET_PREFIX = "ds://" QUERY_DATASET_PREFIX = "ds_query_" +LISTING_PREFIX = "lst__" def parse_dataset_uri(uri: str) -> tuple[str, Optional[int]]: @@ -443,7 +444,11 @@ def is_bucket_listing(self) -> bool: For bucket listing we implicitly create underlying dataset to hold data. This method is checking if this is one of those datasets. 
""" - return Client.is_data_source_uri(self.name) + # TODO refactor and maybe remove method in + # https://github.com/iterative/datachain/issues/318 + return Client.is_data_source_uri(self.name) or self.name.startswith( + LISTING_PREFIX + ) @property def versions_values(self) -> list[int]: diff --git a/src/datachain/lib/dc.py b/src/datachain/lib/dc.py index d34095787..6e61f2603 100644 --- a/src/datachain/lib/dc.py +++ b/src/datachain/lib/dc.py @@ -27,7 +27,15 @@ from datachain.lib.data_model import DataModel, DataType, dict_to_data_model from datachain.lib.dataset_info import DatasetInfo from datachain.lib.file import ExportPlacement as FileExportPlacement -from datachain.lib.file import File, IndexedFile, get_file +from datachain.lib.file import File, IndexedFile, get_file_type +from datachain.lib.listing import ( + is_listing_dataset, + is_listing_expired, + is_listing_subset, + list_bucket, + ls, + parse_listing_uri, +) from datachain.lib.meta_formats import read_meta, read_schema from datachain.lib.model_store import ModelStore from datachain.lib.settings import Settings @@ -311,7 +319,7 @@ def add_schema(self, signals_schema: SignalSchema) -> "Self": # noqa: D102 @classmethod def from_storage( cls, - path, + uri, *, type: Literal["binary", "text", "image"] = "binary", session: Optional[Session] = None, @@ -320,41 +328,79 @@ def from_storage( recursive: Optional[bool] = True, object_name: str = "file", update: bool = False, - **kwargs, + anon: bool = False, ) -> "Self": """Get data from a storage as a list of file with all file attributes. It returns the chain itself as usual. Parameters: - path : storage URI with directory. URI must start with storage prefix such + uri : storage URI with directory. URI must start with storage prefix such as `s3://`, `gs://`, `az://` or "file:///" type : read file as "binary", "text", or "image" data. Default is "binary". recursive : search recursively for the given path. object_name : Created object column name. update : force storage reindexing. Default is False. 
+ anon : If True, we will treat cloud bucket as public one Example: ```py chain = DataChain.from_storage("s3://my-bucket/my-dir") ``` """ - func = get_file(type) - return ( - cls( - path, - session=session, - settings=settings, - recursive=recursive, - update=update, - in_memory=in_memory, - **kwargs, - ) - .map(**{object_name: func}) - .select(object_name) + file_type = get_file_type(type) + + if anon: + client_config = {"anon": True} + else: + client_config = None + + session = Session.get(session, client_config=client_config, in_memory=in_memory) + + list_dataset_name, list_uri, list_path = parse_listing_uri( + uri, session.catalog.cache, session.catalog.client_config ) + need_listing = True + + for ds in cls.datasets( + session=session, in_memory=in_memory, include_listing=True + ).collect("dataset"): + if ( + not is_listing_expired(ds.created_at) # type: ignore[union-attr] + and is_listing_dataset(ds.name) # type: ignore[union-attr] + and is_listing_subset(ds.name, list_dataset_name) # type: ignore[union-attr] + and not update + ): + need_listing = False + list_dataset_name = ds.name # type: ignore[union-attr] + + if need_listing: + # caching new listing to special listing dataset + ( + cls.from_records( + DataChain.DEFAULT_FILE_RECORD, + session=session, + settings=settings, + in_memory=in_memory, + ) + .gen( + list_bucket(list_uri, client_config=session.catalog.client_config), + output={f"{object_name}": File}, + ) + .save(list_dataset_name, listing=True) + ) + + dc = cls.from_dataset(list_dataset_name, session=session) + dc.signals_schema = dc.signals_schema.mutate({f"{object_name}": file_type}) + + return ls(dc, list_path, recursive=recursive, object_name=object_name) @classmethod - def from_dataset(cls, name: str, version: Optional[int] = None) -> "DataChain": + def from_dataset( + cls, + name: str, + version: Optional[int] = None, + session: Optional[Session] = None, + ) -> "DataChain": """Get data from a saved Dataset. It returns the chain itself. Parameters: @@ -366,7 +412,7 @@ def from_dataset(cls, name: str, version: Optional[int] = None) -> "DataChain": chain = DataChain.from_dataset("my_cats") ``` """ - return DataChain(name=name, version=version) + return DataChain(name=name, version=version, session=session) @classmethod def from_json( @@ -419,7 +465,7 @@ def jmespath_to_name(s: str): object_name = jmespath_to_name(jmespath) if not object_name: object_name = meta_type - chain = DataChain.from_storage(path=path, type=type, **kwargs) + chain = DataChain.from_storage(uri=path, type=type, **kwargs) signal_dict = { object_name: read_meta( schema_from=schema_from, @@ -479,7 +525,7 @@ def jmespath_to_name(s: str): object_name = jmespath_to_name(jmespath) if not object_name: object_name = meta_type - chain = DataChain.from_storage(path=path, type=type, **kwargs) + chain = DataChain.from_storage(uri=path, type=type, **kwargs) signal_dict = { object_name: read_meta( schema_from=schema_from, @@ -500,6 +546,7 @@ def datasets( settings: Optional[dict] = None, in_memory: bool = False, object_name: str = "dataset", + include_listing: bool = False, ) -> "DataChain": """Generate chain with list of registered datasets. 
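To make the new `from_storage` flow above concrete, here is a minimal usage sketch. It relies only on what this diff introduces (`uri` as the first argument, `anon`, `update`, and `datasets(include_listing=...)`); the bucket URI is hypothetical.

```py
from datachain.lib.dc import DataChain

# First call lists the bucket and caches the result in a hidden
# "lst__..." listing dataset; later calls within the listing TTL reuse it.
chain = DataChain.from_storage("s3://my-bucket/images/", type="image", anon=True)

# update=True ignores any cached (non-expired) listing and re-lists the prefix.
fresh = DataChain.from_storage("s3://my-bucket/images/", update=True)

# Listing datasets are hidden by default; include_listing=True exposes them.
for info in DataChain.datasets(include_listing=True).collect("dataset"):
    print(info.name, info.version)
```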
@@ -517,7 +564,9 @@ def datasets( datasets = [ DatasetInfo.from_models(d, v, j) - for d, v, j in catalog.list_datasets_versions() + for d, v, j in catalog.list_datasets_versions( + include_listing=include_listing + ) ] return cls.from_values( @@ -570,7 +619,7 @@ def print_jsonl_schema( # type: ignore[override] ) def save( # type: ignore[override] - self, name: Optional[str] = None, version: Optional[int] = None + self, name: Optional[str] = None, version: Optional[int] = None, **kwargs ) -> "Self": """Save to a Dataset. It returns the chain itself. @@ -580,7 +629,7 @@ def save( # type: ignore[override] version : version of a dataset. Default - the last version that exist. """ schema = self.signals_schema.clone_without_sys_signals().serialize() - return super().save(name=name, version=version, feature_schema=schema) + return super().save(name=name, version=version, feature_schema=schema, **kwargs) def apply(self, func, *args, **kwargs): """Apply any function to the chain. @@ -1665,7 +1714,10 @@ def from_records( if schema: signal_schema = SignalSchema(schema) - columns = signal_schema.db_signals(as_columns=True) # type: ignore[assignment] + columns = [ + sqlalchemy.Column(c.name, c.type) # type: ignore[union-attr] + for c in signal_schema.db_signals(as_columns=True) # type: ignore[assignment] + ] else: columns = [ sqlalchemy.Column(name, typ) diff --git a/src/datachain/lib/file.py b/src/datachain/lib/file.py index b02fab1ed..2ee5c1400 100644 --- a/src/datachain/lib/file.py +++ b/src/datachain/lib/file.py @@ -349,39 +349,6 @@ def save(self, destination: str): self.read().save(destination) -def get_file(type_: Literal["binary", "text", "image"] = "binary"): - file: type[File] = File - if type_ == "text": - file = TextFile - elif type_ == "image": - file = ImageFile # type: ignore[assignment] - - def get_file_type( - source: str, - path: str, - size: int, - version: str, - etag: str, - is_latest: bool, - last_modified: datetime, - location: Optional[Union[dict, list[dict]]], - vtype: str, - ) -> file: # type: ignore[valid-type] - return file( - source=source, - path=path, - size=size, - version=version, - etag=etag, - is_latest=is_latest, - last_modified=last_modified, - location=location, - vtype=vtype, - ) - - return get_file_type - - class IndexedFile(DataModel): """Metadata indexed from tabular files. 
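The removed `get_file()` factory above returned a constructor closure; the next hunk replaces it with `get_file_type()`, which simply returns the `File` subclass so that `dc.py` can swap it into the schema via `mutate()`. A small sketch of the intended use:

```py
from datachain.lib.file import File, ImageFile, TextFile, get_file_type

# get_file_type() maps the requested read mode to a File subclass.
assert get_file_type() is File
assert get_file_type("text") is TextFile
assert get_file_type("image") is ImageFile

# In from_storage() the returned class replaces the plain File signal:
#   dc.signals_schema = dc.signals_schema.mutate({object_name: file_type})
```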
@@ -390,3 +357,13 @@ class IndexedFile(DataModel): file: File index: int + + +def get_file_type(type_: Literal["binary", "text", "image"] = "binary") -> type[File]: + file: type[File] = File + if type_ == "text": + file = TextFile + elif type_ == "image": + file = ImageFile # type: ignore[assignment] + + return file diff --git a/src/datachain/lib/listing.py b/src/datachain/lib/listing.py index 47e247a6c..b8cdf7c01 100644 --- a/src/datachain/lib/listing.py +++ b/src/datachain/lib/listing.py @@ -1,11 +1,23 @@ +import posixpath from collections.abc import Iterator -from typing import Callable +from datetime import datetime, timedelta, timezone +from typing import TYPE_CHECKING, Callable, Optional from fsspec.asyn import get_loop +from sqlalchemy.sql.expression import true from datachain.asyn import iter_over_async from datachain.client import Client from datachain.lib.file import File +from datachain.query.schema import Column +from datachain.sql.functions import path as pathfunc +from datachain.utils import uses_glob + +if TYPE_CHECKING: + from datachain.lib.dc import DataChain + +LISTING_TTL = 4 * 60 * 60 # cached listing lasts 4 hours +LISTING_PREFIX = "lst__" # listing datasets start with this name def list_bucket(uri: str, client_config=None) -> Callable: @@ -17,8 +29,84 @@ def list_bucket(uri: str, client_config=None) -> Callable: def list_func() -> Iterator[File]: config = client_config or {} client, path = Client.parse_url(uri, None, **config) # type: ignore[arg-type] - for entries in iter_over_async(client.scandir(path), get_loop()): + for entries in iter_over_async(client.scandir(path.rstrip("/")), get_loop()): for entry in entries: yield entry.to_file(client.uri) return list_func + + +def ls( + dc: "DataChain", + path: str, + recursive: Optional[bool] = True, + object_name="file", +): + """ + Return files by some path from DataChain instance which contains bucket listing. + Path can have globs. 
+ If recursive is set to False, only first level children will be returned by + specified path + """ + + def _file_c(name: str) -> Column: + return Column(f"{object_name}.{name}") + + dc = dc.filter(_file_c("is_latest") == true()) + + if recursive: + if not path or path == "/": + # root of a bucket, returning all latest files from it + return dc + + if not uses_glob(path): + # path is not glob, so it's pointing to some directory or a specific + # file and we are adding proper filter for it + return dc.filter( + (_file_c("path") == path) + | (_file_c("path").glob(path.rstrip("/") + "/*")) + ) + + # path has glob syntax so we are returning glob filter + return dc.filter(_file_c("path").glob(path)) + # returning only first level children by path + return dc.filter(pathfunc.parent(_file_c("path")) == path.lstrip("/").rstrip("/*")) + + +def parse_listing_uri(uri: str, cache, client_config) -> tuple[str, str, str]: + """ + Parsing uri and returns listing dataset name, listing uri and listing path + """ + client, path = Client.parse_url(uri, cache, **client_config) + + # clean path without globs + lst_uri_path = ( + posixpath.dirname(path) if uses_glob(path) or client.fs.isfile(uri) else path + ) + + lst_uri = f"{client.uri}/{lst_uri_path.lstrip('/')}" + ds_name = ( + f"{LISTING_PREFIX}{client.uri}/{posixpath.join(lst_uri_path, '').lstrip('/')}" + ) + + return ds_name, lst_uri, path + + +def is_listing_dataset(name: str) -> bool: + """Returns True if it's special listing dataset""" + return name.startswith(LISTING_PREFIX) + + +def is_listing_expired(created_at: datetime) -> bool: + """Checks if listing has expired based on it's creation date""" + return datetime.now(timezone.utc) > created_at + timedelta(seconds=LISTING_TTL) + + +def is_listing_subset(ds1_name: str, ds2_name: str) -> bool: + """ + Checks if one listing contains another one by comparing corresponding dataset names + """ + assert ds1_name.endswith("/") + assert ds2_name.endswith("/") + + return ds2_name.startswith(ds1_name) diff --git a/src/datachain/lib/meta_formats.py b/src/datachain/lib/meta_formats.py index f3290b8c6..321252049 100644 --- a/src/datachain/lib/meta_formats.py +++ b/src/datachain/lib/meta_formats.py @@ -54,10 +54,10 @@ def read_schema(source_file, data_type="csv", expr=None, model_name=None): try: with source_file.open() as fd: # CSV can be larger than memory if data_type == "csv": - data_string += fd.readline().decode("utf-8", "ignore").replace("\r", "") - data_string += fd.readline().decode("utf-8", "ignore").replace("\r", "") + data_string += fd.readline().replace("\r", "") + data_string += fd.readline().replace("\r", "") elif data_type == "jsonl": - data_string = fd.readline().decode("utf-8", "ignore").replace("\r", "") + data_string = fd.readline().replace("\r", "") else: data_string = fd.read() # other meta must fit into RAM except OSError as e: @@ -120,7 +120,7 @@ def read_meta( # noqa: C901 sys.stdout = captured_output try: chain = ( - DataChain.from_storage(schema_from) + DataChain.from_storage(schema_from, type="text") .limit(1) .map( # dummy column created (#1615) meta_schema=lambda file: read_schema( diff --git a/src/datachain/lib/signal_schema.py b/src/datachain/lib/signal_schema.py index 6ab1100f7..447d6fdf9 100644 --- a/src/datachain/lib/signal_schema.py +++ b/src/datachain/lib/signal_schema.py @@ -16,7 +16,6 @@ get_origin, ) -import sqlalchemy as sa from pydantic import BaseModel, create_model from typing_extensions import Literal as LiteralEx @@ -341,7 +340,7 @@ def db_signals( signals = [ 
DEFAULT_DELIMITER.join(path) if not as_columns - else sa.Column(DEFAULT_DELIMITER.join(path), python_to_sql(_type)) + else Column(DEFAULT_DELIMITER.join(path), python_to_sql(_type)) for path, _type, has_subtree, _ in self.get_flat_tree() if not has_subtree ] @@ -415,6 +414,10 @@ def mutate(self, args_map: dict) -> "SignalSchema": # renaming existing signal del new_values[value.name] new_values[name] = self.values[value.name] + elif name in self.values: + # changing the type of existing signal, e.g File -> ImageFile + del new_values[name] + new_values[name] = args_map[name] else: # adding new signal new_values.update(sql_to_python({name: value})) diff --git a/src/datachain/utils.py b/src/datachain/utils.py index 5af9399b8..beaf09f2f 100644 --- a/src/datachain/utils.py +++ b/src/datachain/utils.py @@ -448,3 +448,8 @@ def get_datachain_executable() -> list[str]: if datachain_exec_path := os.getenv("DATACHAIN_EXEC_PATH"): return [datachain_exec_path] return [sys.executable, "-m", "datachain"] + + +def uses_glob(path: str) -> bool: + """Checks if some URI path has glob syntax in it""" + return glob.has_magic(os.path.basename(os.path.normpath(path))) diff --git a/tests/benchmarks/test_datachain.py b/tests/benchmarks/test_datachain.py index 7c130b0ed..5103fbf25 100644 --- a/tests/benchmarks/test_datachain.py +++ b/tests/benchmarks/test_datachain.py @@ -1,14 +1,13 @@ import pytest -from datachain.catalog import get_catalog from datachain.lib.dc import DataChain from datachain.lib.webdataset_laion import process_laion_meta @pytest.mark.benchmark -def test_datachain(tmp_dir, datasets, benchmark): - def run_script(uri, catalog, **kwargs): - DataChain.from_storage(uri, catalog=catalog, **kwargs).gen( +def test_datachain(tmp_dir, test_session, datasets, benchmark): + def run_script(uri, **kwargs): + DataChain.from_storage(uri, session=test_session, **kwargs).gen( emd=process_laion_meta ).map( stem=lambda file: file.get_file_stem(), @@ -16,7 +15,6 @@ def run_script(uri, catalog, **kwargs): output=str, ).save("laion_emb") - catalog = get_catalog() dataset = datasets / "laion-tiny.npz" assert dataset.is_file() - benchmark(run_script, dataset.as_uri(), catalog) + benchmark(run_script, dataset.as_uri()) diff --git a/tests/conftest.py b/tests/conftest.py index e75cee3aa..d4df20784 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -394,6 +394,10 @@ def partial_path(self): def client_config(self): return self.server.client_config + @property + def session(self) -> Session: + return Session("CTCSession", catalog=self.catalog) + cloud_types = ["s3", "gs", "azure"] diff --git a/tests/func/test_datachain.py b/tests/func/test_datachain.py index 51c8b9895..87176458b 100644 --- a/tests/func/test_datachain.py +++ b/tests/func/test_datachain.py @@ -1,7 +1,7 @@ import math import os import re -from datetime import datetime, timezone +from datetime import datetime, timedelta, timezone from pathlib import Path import pandas as pd @@ -14,9 +14,26 @@ from datachain.dataset import DatasetStats from datachain.lib.dc import DataChain, DataChainColumnError from datachain.lib.file import File, ImageFile +from datachain.lib.listing import ( + LISTING_TTL, + is_listing_dataset, + parse_listing_uri, +) from tests.utils import images_equal +def _get_listing_datasets(session): + return sorted( + [ + f"{ds.name}@v{ds.version}" + for ds in DataChain.datasets(session=session, include_listing=True).collect( + "dataset" + ) + if is_listing_dataset(ds.name) + ] + ) + + @pytest.mark.parametrize("anon", [True, False]) def 
test_catalog_anon(tmp_dir, catalog, anon): chain = DataChain.from_storage(tmp_dir.as_uri(), anon=anon) @@ -25,22 +42,140 @@ def test_catalog_anon(tmp_dir, catalog, anon): def test_from_storage(cloud_test_catalog): ctc = cloud_test_catalog - dc = DataChain.from_storage(ctc.src_uri, catalog=ctc.catalog) + dc = DataChain.from_storage(ctc.src_uri, session=ctc.session) assert dc.count() == 7 -def test_from_storage_reindex(tmp_dir, catalog): +def test_from_storage_non_recursive(cloud_test_catalog): + ctc = cloud_test_catalog + dc = DataChain.from_storage( + f"{ctc.src_uri}/dogs", session=ctc.session, recursive=False + ) + assert dc.count() == 3 + + +def test_from_storage_glob(cloud_test_catalog): + ctc = cloud_test_catalog + dc = DataChain.from_storage(f"{ctc.src_uri}/dogs*", session=ctc.session) + assert dc.count() == 4 + + +def test_from_storage_as_image(cloud_test_catalog): + ctc = cloud_test_catalog + dc = DataChain.from_storage(ctc.src_uri, session=ctc.session, type="image") + for im in dc.collect("file"): + assert isinstance(im, ImageFile) + + +def test_from_storage_reindex(tmp_dir, test_session): + tmp_dir = tmp_dir / "parquets" path = tmp_dir.as_uri() + os.mkdir(tmp_dir) pd.DataFrame({"name": ["Alice", "Bob"]}).to_parquet(tmp_dir / "test1.parquet") - assert DataChain.from_storage(path, catalog=catalog).count() == 1 + assert DataChain.from_storage(path, session=test_session).count() == 1 pd.DataFrame({"name": ["Charlie", "David"]}).to_parquet(tmp_dir / "test2.parquet") - assert DataChain.from_storage(path, catalog=catalog).count() == 1 - assert DataChain.from_storage(path, catalog=catalog, update=True).count() == 2 + assert DataChain.from_storage(path, session=test_session).count() == 1 + assert DataChain.from_storage(path, session=test_session, update=True).count() == 2 -@pytest.mark.parametrize("use_cache", [False, True]) +def test_from_storage_reindex_expired(tmp_dir, test_session): + catalog = test_session.catalog + tmp_dir = tmp_dir / "parquets" + os.mkdir(tmp_dir) + uri = tmp_dir.as_uri() + + lst_ds_name = parse_listing_uri(uri, catalog.cache, catalog.client_config)[0] + + pd.DataFrame({"name": ["Alice", "Bob"]}).to_parquet(tmp_dir / "test1.parquet") + assert DataChain.from_storage(uri, session=test_session).count() == 1 + pd.DataFrame({"name": ["Charlie", "David"]}).to_parquet(tmp_dir / "test2.parquet") + # mark dataset as expired + test_session.catalog.metastore.update_dataset_version( + test_session.catalog.get_dataset(lst_ds_name), + 1, + created_at=datetime.now(timezone.utc) - timedelta(seconds=LISTING_TTL + 20), + ) + + # listing was updated because listing dataset was expired + assert DataChain.from_storage(uri, session=test_session).count() == 2 + + +@pytest.mark.parametrize( + "cloud_type", + ["s3", "azure", "gs"], + indirect=True, +) +def test_from_storage_partials(cloud_test_catalog): + ctc = cloud_test_catalog + src_uri = ctc.src_uri + session = ctc.session + catalog = session.catalog + + def _list_dataset_name(uri: str) -> str: + return parse_listing_uri(uri, catalog.cache, catalog.client_config)[0] + + dogs_uri = f"{src_uri}/dogs" + DataChain.from_storage(dogs_uri, session=session) + assert _get_listing_datasets(session) == [ + f"{_list_dataset_name(dogs_uri)}@v1", + ] + + DataChain.from_storage(f"{src_uri}/dogs/others", session=session) + assert _get_listing_datasets(session) == [ + f"{_list_dataset_name(dogs_uri)}@v1", + ] + + DataChain.from_storage(src_uri, session=session) + assert _get_listing_datasets(session) == sorted( + [ + 
f"{_list_dataset_name(dogs_uri)}@v1", + f"{_list_dataset_name(src_uri)}@v1", + ] + ) + + DataChain.from_storage(f"{src_uri}/cats", session=session) + assert _get_listing_datasets(session) == sorted( + [ + f"{_list_dataset_name(dogs_uri)}@v1", + f"{_list_dataset_name(src_uri)}@v1", + ] + ) + + +@pytest.mark.parametrize( + "cloud_type", + ["s3", "azure", "gs"], + indirect=True, +) +def test_from_storage_partials_with_update(cloud_test_catalog): + ctc = cloud_test_catalog + src_uri = ctc.src_uri + session = ctc.session + catalog = session.catalog + + def _list_dataset_name(uri: str) -> str: + return parse_listing_uri(uri, catalog.cache, catalog.client_config)[0] + + uri = f"{src_uri}/cats" + DataChain.from_storage(uri, session=session) + assert _get_listing_datasets(session) == sorted( + [ + f"{_list_dataset_name(uri)}@v1", + ] + ) + + DataChain.from_storage(uri, session=session, update=True) + assert _get_listing_datasets(session) == sorted( + [ + f"{_list_dataset_name(uri)}@v1", + f"{_list_dataset_name(uri)}@v2", + ] + ) + + +@pytest.mark.parametrize("use_cache", [True, False]) def test_map_file(cloud_test_catalog, use_cache): ctc = cloud_test_catalog @@ -49,10 +184,11 @@ def new_signal(file: File) -> str: return file.name + " -> " + f.read().decode("utf-8") dc = ( - DataChain.from_storage(ctc.src_uri, catalog=ctc.catalog) + DataChain.from_storage(ctc.src_uri, session=ctc.session) .settings(cache=use_cache) .map(signal=new_signal) ) + expected = { "description -> Cats and Dogs", "cat1 -> meow", @@ -71,7 +207,7 @@ def new_signal(file: File) -> str: def test_read_file(cloud_test_catalog, use_cache): ctc = cloud_test_catalog - dc = DataChain.from_storage(ctc.src_uri, catalog=ctc.catalog) + dc = DataChain.from_storage(ctc.src_uri, session=ctc.session) for file in dc.settings(cache=use_cache).collect("file"): assert file.get_local_path() is None file.read() @@ -84,10 +220,10 @@ def test_read_file(cloud_test_catalog, use_cache): @pytest.mark.parametrize("file_type", ["", "binary", "text"]) @pytest.mark.parametrize("cloud_type", ["file"], indirect=True) def test_export_files( - tmp_dir, cloud_test_catalog, placement, use_map, use_cache, file_type + tmp_dir, cloud_test_catalog, test_session, placement, use_map, use_cache, file_type ): ctc = cloud_test_catalog - df = DataChain.from_storage(ctc.src_uri, type=file_type, catalog=ctc.catalog) + df = DataChain.from_storage(ctc.src_uri, type=file_type, session=test_session) if use_map: df.export_files(tmp_dir / "output", placement=placement, use_cache=use_cache) df.map( diff --git a/tests/func/test_feature_pickling.py b/tests/func/test_feature_pickling.py index 9451f3e8f..027c63148 100644 --- a/tests/func/test_feature_pickling.py +++ b/tests/func/test_feature_pickling.py @@ -74,8 +74,9 @@ def sort_df_for_tests(df): indirect=True, ) def test_feature_udf_parallel(cloud_test_catalog_tmpfile): - catalog = cloud_test_catalog_tmpfile.catalog - source = cloud_test_catalog_tmpfile.src_uri + ctc = cloud_test_catalog_tmpfile + catalog = ctc.catalog + source = ctc.src_uri catalog.index([source]) import tests.func.test_feature_pickling as tfp # noqa: PLW0406 @@ -84,8 +85,8 @@ def test_feature_udf_parallel(cloud_test_catalog_tmpfile): cloudpickle.register_pickle_by_value(tfp) chain = ( - DataChain.from_storage(source, type="text", catalog=catalog) - .filter(C.path.glob("*cat*")) + DataChain.from_storage(source, type="text", session=ctc.session) + .filter(C("file.path").glob("*cat*")) .settings(parallel=2) .map( message=file_to_message, @@ -107,8 +108,9 @@ def 
test_feature_udf_parallel(cloud_test_catalog_tmpfile): indirect=True, ) def test_feature_udf_parallel_local(cloud_test_catalog_tmpfile): - catalog = cloud_test_catalog_tmpfile.catalog - source = cloud_test_catalog_tmpfile.src_uri + ctc = cloud_test_catalog_tmpfile + catalog = ctc.catalog + source = ctc.src_uri catalog.index([source]) class FileInfoLocal(DataModel): @@ -132,8 +134,8 @@ class AIMessageLocal(DataModel): cloudpickle.register_pickle_by_value(tfp) chain = ( - DataChain.from_storage(source, type="text", catalog=catalog) - .filter(C.path.glob("*cat*")) + DataChain.from_storage(source, type="text", session=ctc.session) + .filter(C("file.path").glob("*cat*")) .settings(parallel=2) .map( message=lambda file: AIMessageLocal( @@ -164,8 +166,9 @@ class AIMessageLocal(DataModel): indirect=True, ) def test_feature_udf_parallel_local_pydantic(cloud_test_catalog_tmpfile): - catalog = cloud_test_catalog_tmpfile.catalog - source = cloud_test_catalog_tmpfile.src_uri + ctc = cloud_test_catalog_tmpfile + catalog = ctc.catalog + source = ctc.src_uri catalog.index([source]) class FileInfoLocalPydantic(BaseModel): @@ -189,8 +192,8 @@ class AIMessageLocalPydantic(BaseModel): cloudpickle.register_pickle_by_value(tfp) chain = ( - DataChain.from_storage(source, type="text", catalog=catalog) - .filter(C.path.glob("*cat*")) + DataChain.from_storage(source, type="text", session=ctc.session) + .filter(C("file.path").glob("*cat*")) .settings(parallel=2) .map( message=lambda file: AIMessageLocalPydantic( @@ -223,8 +226,10 @@ class AIMessageLocalPydantic(BaseModel): indirect=True, ) def test_feature_udf_parallel_dynamic(cloud_test_catalog_tmpfile): - catalog = cloud_test_catalog_tmpfile.catalog - source = cloud_test_catalog_tmpfile.src_uri + ctc = cloud_test_catalog_tmpfile + catalog = ctc.catalog + source = ctc.src_uri + session = ctc.session catalog.index([source]) file_info_dynamic = create_feature_model( @@ -260,8 +265,8 @@ def test_feature_udf_parallel_dynamic(cloud_test_catalog_tmpfile): cloudpickle.register_pickle_by_value(tfp) chain = ( - DataChain.from_storage(source, type="text", catalog=catalog) - .filter(C.path.glob("*cat*")) + DataChain.from_storage(source, type="text", session=session) + .filter(C("file__path").glob("*cat*")) .settings(parallel=2) .map( message=lambda file: ai_message_dynamic( diff --git a/tests/scripts/feature_class.py b/tests/scripts/feature_class.py index f0cbbb57e..d8a7f8f48 100644 --- a/tests/scripts/feature_class.py +++ b/tests/scripts/feature_class.py @@ -10,7 +10,7 @@ class Embedding(BaseModel): ds_name = "feature_class" ds = ( DataChain.from_storage("gs://dvcx-datalakes/dogs-and-cats/") - .filter(C.path.glob("*cat*.jpg")) + .filter(C("file.path").glob("*cat*.jpg")) .limit(5) .map(emd=lambda file: Embedding(value=512), output=Embedding) ) diff --git a/tests/scripts/feature_class_parallel.py b/tests/scripts/feature_class_parallel.py index b000f5523..e963561d7 100644 --- a/tests/scripts/feature_class_parallel.py +++ b/tests/scripts/feature_class_parallel.py @@ -18,7 +18,7 @@ class Embedding(BaseModel): ds_name = "feature_class" ds = ( DataChain.from_storage("gs://dvcx-datalakes/dogs-and-cats/") - .filter(C.path.glob("*cat*.jpg")) # type: ignore [attr-defined] + .filter(C("file.path").glob("*cat*.jpg")) # type: ignore [attr-defined] .limit(5) .settings(cache=True, parallel=2) .map(emd=lambda file: Embedding(value=512), output=Embedding) diff --git a/tests/scripts/feature_class_parallel_data_model.py b/tests/scripts/feature_class_parallel_data_model.py index 
d8f2cb1db..90f689a18 100644 --- a/tests/scripts/feature_class_parallel_data_model.py +++ b/tests/scripts/feature_class_parallel_data_model.py @@ -17,7 +17,7 @@ class Embedding(DataModel): ds_name = "feature_class" ds = ( DataChain.from_storage("gs://dvcx-datalakes/dogs-and-cats/") - .filter(C.path.glob("*cat*.jpg")) # type: ignore [attr-defined] + .filter(C("file.path").glob("*cat*.jpg")) # type: ignore [attr-defined] .limit(5) .settings(cache=True, parallel=2) .map(emd=lambda file: Embedding(value=512), output=Embedding) diff --git a/tests/test_query_e2e.py b/tests/test_query_e2e.py index 7ea1e2adc..45d63cbc8 100644 --- a/tests/test_query_e2e.py +++ b/tests/test_query_e2e.py @@ -51,6 +51,7 @@ dogs-and-cats/cat.1001.jpg """ ), + "listing": True, }, { "command": ( diff --git a/tests/unit/lib/test_datachain.py b/tests/unit/lib/test_datachain.py index 7c50cacaf..1165581e0 100644 --- a/tests/unit/lib/test_datachain.py +++ b/tests/unit/lib/test_datachain.py @@ -938,7 +938,7 @@ def test_parse_tabular_partitions(tmp_dir, test_session): df.to_parquet(path, partition_cols=["first_name"]) dc = ( DataChain.from_storage(path.as_uri(), session=test_session) - .filter(C("path").glob("*first_name=Alice*")) + .filter(C("file.path").glob("*first_name=Alice*")) .parse_tabular(partitioning="hive") ) df1 = dc.select("first_name", "age", "city").to_pandas() @@ -968,7 +968,7 @@ def test_parse_tabular_unify_schema(tmp_dir, test_session): ) dc = ( DataChain.from_storage(tmp_dir.as_uri(), session=test_session) - .filter(C("path").glob("*.parquet")) + .filter(C("file.path").glob("*.parquet")) .parse_tabular() ) df = dc.select("first_name", "age", "city", "last_name", "country").to_pandas() diff --git a/tests/unit/lib/test_signal_schema.py b/tests/unit/lib/test_signal_schema.py index 518c1d398..e6475a783 100644 --- a/tests/unit/lib/test_signal_schema.py +++ b/tests/unit/lib/test_signal_schema.py @@ -3,11 +3,10 @@ from typing import Optional, Union import pytest -from sqlalchemy import Column -from datachain import DataModel +from datachain import Column, DataModel from datachain.lib.convert.flatten import flatten -from datachain.lib.file import File +from datachain.lib.file import File, TextFile from datachain.lib.signal_schema import ( SetupError, SignalResolvingError, @@ -445,6 +444,24 @@ def test_slice_nested(): assert list(sliced.values.items()) == [("feature.aa", int)] +def test_mutate_rename(): + schema = SignalSchema({"name": str}) + schema = schema.mutate({"new_name": Column("name")}) + assert schema.values == {"new_name": str} + + +def test_mutate_new_signal(): + schema = SignalSchema({"name": str}) + schema = schema.mutate({"age": Column("age", Float)}) + assert schema.values == {"name": str, "age": float} + + +def test_mutate_change_type(): + schema = SignalSchema({"name": str, "age": float, "f": File}) + schema = schema.mutate({"age": int, "f": TextFile}) + assert schema.values == {"name": str, "age": int, "f": TextFile} + + @pytest.mark.parametrize( "column_type,signal_type", [ diff --git a/tests/unit/test_listing.py b/tests/unit/test_listing.py index d25da56a3..7c9e751b0 100644 --- a/tests/unit/test_listing.py +++ b/tests/unit/test_listing.py @@ -1,9 +1,17 @@ import posixpath +from datetime import datetime, timedelta, timezone import pytest from datachain.catalog import Catalog from datachain.catalog.catalog import DataSource +from datachain.lib.listing import ( + LISTING_TTL, + is_listing_dataset, + is_listing_expired, + is_listing_subset, + parse_listing_uri, +) from datachain.node import DirType, 
Entry, get_path from tests.utils import skip_if_not_sqlite @@ -138,3 +146,74 @@ def test_subtree(listing): def test_subdirs(listing): dirs = list(listing.get_dirs_by_parent_path("")) _match_filenames(dirs, ["dir1", "dir2"]) + + +@pytest.mark.parametrize( + "cloud_type", + ["s3", "azure", "gs"], + indirect=True, +) +def test_parse_listing_uri(cloud_test_catalog): + ctc = cloud_test_catalog + catalog = ctc.catalog + dataset_name, listing_uri, listing_path = parse_listing_uri( + f"{ctc.src_uri}/dogs", catalog.cache, catalog.client_config + ) + assert dataset_name == f"lst__{ctc.src_uri}/dogs/" + assert listing_uri == f"{ctc.src_uri}/dogs" + assert listing_path == "dogs" + + +@pytest.mark.parametrize( + "cloud_type", + ["s3", "azure", "gs"], + indirect=True, +) +def test_parse_listing_uri_with_glob(cloud_test_catalog): + ctc = cloud_test_catalog + catalog = ctc.catalog + dataset_name, listing_uri, listing_path = parse_listing_uri( + f"{ctc.src_uri}/dogs/*", catalog.cache, catalog.client_config + ) + assert dataset_name == f"lst__{ctc.src_uri}/dogs/" + assert listing_uri == f"{ctc.src_uri}/dogs" + assert listing_path == "dogs/*" + + +@pytest.mark.parametrize( + "name,is_listing", + [ + ("lst__s3://my-bucket", True), + ("lst__file:///my-folder/dir1", True), + ("s3://my-bucket", False), + ("my-dataset", False), + ], +) +def test_is_listing_dataset(name, is_listing): + assert is_listing_dataset(name) is is_listing + + +@pytest.mark.parametrize( + "date,is_expired", + [ + (datetime.now(timezone.utc), False), + (datetime.now(timezone.utc) - timedelta(seconds=LISTING_TTL + 1), True), + ], +) +def test_is_listing_expired(date, is_expired): + assert is_listing_expired(date) is is_expired + + +@pytest.mark.parametrize( + "ds1_name,ds2_name,is_subset", + [ + ("lst__s3://my-bucket/animals/", "lst__s3://my-bucket/animals/dogs/", True), + ("lst__s3://my-bucket/animals/", "lst__s3://my-bucket/animals/", True), + ("lst__s3://my-bucket/", "lst__s3://my-bucket/", True), + ("lst__s3://my-bucket/cats/", "lst__s3://my-bucket/animals/dogs/", False), + ("lst__s3://my-bucket/dogs/", "lst__s3://my-bucket/animals/", False), + ("lst__s3://my-bucket/animals/", "lst__s3://other-bucket/animals/", False), + ], +) +def test_listing_subset(ds1_name, ds2_name, is_subset): + assert is_listing_subset(ds1_name, ds2_name) is is_subset diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 12e705348..3fe25328f 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -11,6 +11,7 @@ sizeof_fmt, sql_escape_like, suffix_to_number, + uses_glob, ) DATACHAIN_TEST_PATHS = ["/file1", "file2", "/dir/file3", "dir/file4"] @@ -178,3 +179,19 @@ def test_determine_processes(parallel, settings, expected): if settings is not None: os.environ["DATACHAIN_SETTINGS_PARALLEL"] = settings assert determine_processes(parallel) == expected + + +@pytest.mark.parametrize( + "path,expected", + ( + ("/dogs", False), + ("/dogs/", False), + ("/dogs/*", True), + ("/home/user/bucket/animals/", False), + ("/home/user/bucket/animals/*", True), + ("", False), + ("*", True), + ), +) +def test_uses_glob(path, expected): + assert uses_glob(path) is expected
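As a quick, self-contained sketch of the listing helpers exercised by the tests above (bucket names are hypothetical):

```py
from datachain.lib.listing import is_listing_dataset, is_listing_subset
from datachain.utils import uses_glob

# Glob detection looks only at the last path component.
assert uses_glob("animals/*") is True
assert uses_glob("animals/") is False

# Listing datasets are named "lst__<storage uri>/<path>/".
assert is_listing_dataset("lst__s3://my-bucket/animals/")

# A listing of a prefix covers listings of its sub-prefixes.
assert is_listing_subset(
    "lst__s3://my-bucket/animals/", "lst__s3://my-bucket/animals/dogs/"
)
```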