diff --git a/src/datachain/catalog/catalog.py b/src/datachain/catalog/catalog.py
index 82ae154bd..a3c16503b 100644
--- a/src/datachain/catalog/catalog.py
+++ b/src/datachain/catalog/catalog.py
@@ -240,7 +240,8 @@ def do_task(self, urls):
 class NodeGroup:
     """Class for a group of nodes from the same source"""
 
-    listing: "Listing"
+    listing: Optional["Listing"]
+    client: "Client"
     sources: list[DataSource]
 
     # The source path within the bucket
@@ -268,9 +269,7 @@ def download(self, recursive: bool = False, pbar=None) -> None:
         Download this node group to cache.
         """
         if self.sources:
-            self.listing.client.fetch_nodes(
-                self.iternodes(recursive), shared_progress_bar=pbar
-            )
+            self.client.fetch_nodes(self.iternodes(recursive), shared_progress_bar=pbar)
 
 
 def check_output_dataset_file(
@@ -375,7 +374,7 @@ def collect_nodes_for_cp(
 
     # Collect all sources to process
     for node_group in node_groups:
-        listing: Listing = node_group.listing
+        listing: Optional[Listing] = node_group.listing
         valid_sources: list[DataSource] = []
         for dsrc in node_group.sources:
             if dsrc.is_single_object():
@@ -383,6 +382,7 @@ def collect_nodes_for_cp(
                 total_files += 1
                 valid_sources.append(dsrc)
             else:
+                assert listing
                 node = dsrc.node
                 if not recursive:
                     print(f"{node.full_path} is a directory (not copied).")
@@ -433,37 +433,51 @@ def instantiate_node_groups(
     )
     output_dir = output
+    output_file = None
     if copy_to_filename:
         output_dir = os.path.dirname(output)
         if not output_dir:
             output_dir = "."
+        output_file = os.path.basename(output)
 
     # Instantiate these nodes
     for node_group in node_groups:
         if not node_group.sources:
             continue
-        listing: Listing = node_group.listing
+        listing: Optional[Listing] = node_group.listing
         source_path: str = node_group.source_path
 
         copy_dir_contents = always_copy_dir_contents or source_path.endswith("/")
-        instantiated_nodes = listing.collect_nodes_to_instantiate(
-            node_group.sources,
-            copy_to_filename,
-            recursive,
-            copy_dir_contents,
-            source_path,
-            node_group.is_edatachain,
-            node_group.is_dataset,
-        )
-        if not virtual_only:
-            listing.instantiate_nodes(
-                instantiated_nodes,
-                output_dir,
-                total_files,
-                force=force,
-                shared_progress_bar=instantiate_progress_bar,
+        if not listing:
+            source = node_group.sources[0]
+            client = source.client
+            node = NodeWithPath(source.node, [output_file or source.node.path])
+            instantiated_nodes = [node]
+            if not virtual_only:
+                node.instantiate(
+                    client, output_dir, instantiate_progress_bar, force=force
+                )
+        else:
+            instantiated_nodes = listing.collect_nodes_to_instantiate(
+                node_group.sources,
+                copy_to_filename,
+                recursive,
+                copy_dir_contents,
+                source_path,
+                node_group.is_edatachain,
+                node_group.is_dataset,
             )
+            if not virtual_only:
+                listing.instantiate_nodes(
+                    instantiated_nodes,
+                    output_dir,
+                    total_files,
+                    force=force,
+                    shared_progress_bar=instantiate_progress_bar,
+                )
+
         node_group.instantiated_nodes = instantiated_nodes
+
     if instantiate_progress_bar:
         instantiate_progress_bar.close()
@@ -592,7 +606,7 @@ def enlist_source(
         client_config=None,
         object_name="file",
         skip_indexing=False,
-    ) -> tuple["Listing", str]:
+    ) -> tuple[Optional["Listing"], "Client", str]:
         from datachain.lib.dc import DataChain
         from datachain.listing import Listing
 
@@ -603,16 +617,19 @@ def enlist_source(
         list_ds_name, list_uri, list_path, _ = get_listing(
             source, self.session, update=update
         )
 
+        lst = None
+        client = Client.get_client(list_uri, self.cache, **self.client_config)
+
+        if list_ds_name:
+            lst = Listing(
+                self.metastore.clone(),
+                self.warehouse.clone(),
+                client,
+                dataset_name=list_ds_name,
+                object_name=object_name,
+            )
-        lst = Listing(
-            self.metastore.clone(),
-            self.warehouse.clone(),
-            Client.get_client(list_uri, self.cache, **self.client_config),
-            dataset_name=list_ds_name,
-            object_name=object_name,
-        )
-
-        return lst, list_path
+        return lst, client, list_path
 
     def _remove_dataset_rows_and_warehouse_info(
         self, dataset: DatasetRecord, version: int, **kwargs
@@ -635,13 +652,13 @@ def enlist_sources(
     ) -> Optional[list["DataSource"]]:
         enlisted_sources = []
         for src in sources:  # Opt: parallel
-            listing, file_path = self.enlist_source(
+            listing, client, file_path = self.enlist_source(
                 src,
                 update,
                 client_config=client_config or self.client_config,
                 skip_indexing=skip_indexing,
             )
-            enlisted_sources.append((listing, file_path))
+            enlisted_sources.append((listing, client, file_path))
 
         if only_index:
             # sometimes we don't really need listing result (e.g on indexing process)
@@ -649,10 +666,16 @@ def enlist_sources(
             return None
 
         dsrc_all: list[DataSource] = []
-        for listing, file_path in enlisted_sources:
-            nodes = listing.expand_path(file_path)
-            dir_only = file_path.endswith("/")
-            dsrc_all.extend(DataSource(listing, node, dir_only) for node in nodes)
+        for listing, client, file_path in enlisted_sources:
+            if not listing:
+                nodes = [Node.from_file(client.get_file_info(file_path))]
+                dir_only = False
+            else:
+                nodes = listing.expand_path(file_path)
+                dir_only = file_path.endswith("/")
+            dsrc_all.extend(
+                DataSource(listing, client, node, dir_only) for node in nodes
+            )
         return dsrc_all
 
     def enlist_sources_grouped(
@@ -667,7 +690,7 @@ def _row_to_node(d: dict[str, Any]) -> Node:
             del d["file__source"]
-            return Node.from_dict(d)
+            return Node.from_row(d)
 
         enlisted_sources: list[tuple[bool, bool, Any]] = []
         client_config = client_config or self.client_config
@@ -677,7 +700,7 @@ def _row_to_node(d: dict[str, Any]) -> Node:
                 edatachain_data = parse_edatachain_file(src)
                 indexed_sources = []
                 for ds in edatachain_data:
-                    listing, source_path = self.enlist_source(
+                    listing, _, source_path = self.enlist_source(
                         ds["data-source"]["uri"],
                         update,
                         client_config=client_config,
@@ -701,6 +724,7 @@ def _row_to_node(d: dict[str, Any]) -> Node:
                 client = self.get_client(source, **client_config)
                 uri = client.uri
                 dataset_name, _, _, _ = get_listing(uri, self.session)
+                assert dataset_name
                 listing = Listing(
                     self.metastore.clone(),
                     self.warehouse.clone(),
@@ -713,6 +737,7 @@ def _row_to_node(d: dict[str, Any]) -> Node:
                 indexed_sources.append(
                     (
                         listing,
+                        client,
                         source,
                         [_row_to_node(r) for r in rows],
                         ds_name,
@@ -722,25 +747,28 @@ def _row_to_node(d: dict[str, Any]) -> Node:
 
                 enlisted_sources.append((False, True, indexed_sources))
             else:
-                listing, source_path = self.enlist_source(
+                listing, client, source_path = self.enlist_source(
                     src, update, client_config=client_config
                 )
-                enlisted_sources.append((False, False, (listing, source_path)))
+                enlisted_sources.append((False, False, (listing, client, source_path)))
 
         node_groups = []
         for is_datachain, is_dataset, payload in enlisted_sources:  # Opt: parallel
             if is_dataset:
                 for (
                     listing,
+                    client,
                     source_path,
                     nodes,
                     dataset_name,
                     dataset_version,
                 ) in payload:
-                    dsrc = [DataSource(listing, node) for node in nodes]
+                    assert listing
+                    dsrc = [DataSource(listing, client, node) for node in nodes]
                     node_groups.append(
                         NodeGroup(
                             listing,
+                            client,
                             dsrc,
                             source_path,
                             dataset_name=dataset_name,
@@ -749,18 +777,30 @@ def _row_to_node(d: dict[str, Any]) -> Node:
                     )
             elif is_datachain:
                 for listing, source_path, paths in payload:
-                    dsrc = [DataSource(listing, listing.resolve_path(p)) for p in paths]
+                    assert listing
+                    dsrc = [
+                        DataSource(listing, listing.client, listing.resolve_path(p))
+                        for p in paths
+                    ]
                     node_groups.append(
-                        NodeGroup(listing, dsrc, source_path, is_edatachain=True)
+                        NodeGroup(
+                            listing,
+                            listing.client,
+                            dsrc,
+                            source_path,
+                            is_edatachain=True,
+                        )
                     )
             else:
-                listing, source_path = payload
-                as_container = source_path.endswith("/")
-                dsrc = [
-                    DataSource(listing, n, as_container)
-                    for n in listing.expand_path(source_path, use_glob=not no_glob)
-                ]
-                node_groups.append(NodeGroup(listing, dsrc, source_path))
+                listing, client, source_path = payload
+                if not listing:
+                    nodes = [Node.from_file(client.get_file_info(source_path))]
+                    as_container = False
+                else:
+                    as_container = source_path.endswith("/")
+                    nodes = listing.expand_path(source_path, use_glob=not no_glob)
+                dsrc = [DataSource(listing, client, n, as_container) for n in nodes]
+                node_groups.append(NodeGroup(listing, client, dsrc, source_path))
 
         return node_groups
diff --git a/src/datachain/catalog/datasource.py b/src/datachain/catalog/datasource.py
index 18945d4c9..145780374 100644
--- a/src/datachain/catalog/datasource.py
+++ b/src/datachain/catalog/datasource.py
@@ -4,21 +4,19 @@
 
 
 class DataSource:
-    def __init__(self, listing, node, as_container=False):
+    def __init__(self, listing, client, node, as_container=False):
         self.listing = listing
+        self.client = client
         self.node = node
         self.as_container = (
             as_container  # Indicates whether a .tar file is handled as a container
         )
 
-    def get_full_path(self):
-        return self.get_node_full_path(self.node)
-
     def get_node_full_path(self, node):
-        return self.listing.client.get_full_path(node.full_path)
+        return self.client.get_full_path(node.full_path)
 
     def get_node_full_path_from_path(self, full_path):
-        return self.listing.client.get_full_path(full_path)
+        return self.client.get_full_path(full_path)
 
     def is_single_object(self):
         return self.node.dir_type == DirType.FILE or (
diff --git a/src/datachain/client/fsspec.py b/src/datachain/client/fsspec.py
index b03abae12..a53a026fa 100644
--- a/src/datachain/client/fsspec.py
+++ b/src/datachain/client/fsspec.py
@@ -204,6 +204,10 @@ async def get_current_etag(self, file: "File") -> str:
         info = await self.fs._info(self.get_full_path(file.path))
         return self.info_to_file(info, "").etag
 
+    def get_file_info(self, path: str) -> "File":
+        info = self.fs.info(self.get_full_path(path))
+        return self.info_to_file(info, path)
+
     async def get_size(self, path: str) -> int:
         return await self.fs._size(path)
 
diff --git a/src/datachain/dataset.py b/src/datachain/dataset.py
index b2d12b611..0bd3b6ee0 100644
--- a/src/datachain/dataset.py
+++ b/src/datachain/dataset.py
@@ -92,6 +92,7 @@ def dataset_name(self) -> str:
             return self.name
 
         list_dataset_name, _, _ = parse_listing_uri(self.name.strip("/"), None, {})
+        assert list_dataset_name
        return list_dataset_name
 
     @classmethod
diff --git a/src/datachain/lib/dc.py b/src/datachain/lib/dc.py
index 303b769ad..ca09a2e13 100644
--- a/src/datachain/lib/dc.py
+++ b/src/datachain/lib/dc.py
@@ -32,7 +32,7 @@
 from datachain.lib.dataset_info import DatasetInfo
 from datachain.lib.file import ArrowRow, File, FileType, get_file_type
 from datachain.lib.file import ExportPlacement as FileExportPlacement
-from datachain.lib.listing import get_listing, list_bucket, ls
+from datachain.lib.listing import get_file_info, get_listing, list_bucket, ls
 from datachain.lib.listing_info import ListingInfo
 from datachain.lib.meta_formats import read_meta
 from datachain.lib.model_store import ModelStore
@@ -438,6 +438,18 @@ def from_storage(
             uri, session, update=update
         )
 
+        # list_ds_name is None if the object is a file; we don't want to use the
+        # cache or do a listing in that case - just read that single object
+        if not list_ds_name:
+            dc = cls.from_values(
+                session=session,
+                settings=settings,
+                in_memory=in_memory,
+                file=[get_file_info(list_uri, cache, client_config=client_config)],
+            )
+            dc.signals_schema = dc.signals_schema.mutate({f"{object_name}": file_type})
+            return dc
+
         if update or not list_ds_exists:
             (
                 cls.from_records(
@@ -1634,7 +1646,7 @@ def from_values(
         output: OutputType = None,
         object_name: str = "",
         **fr_map,
-    ) -> "DataChain":
+    ) -> "Self":
         """Generate chain from list of values.
 
         Example:
@@ -1647,7 +1659,7 @@ def from_values(
         def _func_fr() -> Iterator[tuple_type]:  # type: ignore[valid-type]
             yield from tuples
 
-        chain = DataChain.from_records(
+        chain = cls.from_records(
             DataChain.DEFAULT_FILE_RECORD,
             session=session,
             settings=settings,
diff --git a/src/datachain/lib/listing.py b/src/datachain/lib/listing.py
index 7955592e5..1829f99fc 100644
--- a/src/datachain/lib/listing.py
+++ b/src/datachain/lib/listing.py
@@ -39,6 +39,15 @@ def list_func() -> Iterator[File]:
     return list_func
 
 
+def get_file_info(uri: str, cache, client_config=None) -> File:
+    """
+    Return a File object for the given URI
+    """
+    client = Client.get_client(uri, cache, **(client_config or {}))  # type: ignore[arg-type]
+    _, path = Client.parse_url(uri)
+    return client.get_file_info(path)
+
+
 def ls(
     dc: D,
     path: str,
@@ -76,7 +85,7 @@ def _file_c(name: str) -> Column:
     return dc.filter(pathfunc.parent(_file_c("path")) == path.lstrip("/").rstrip("/*"))
 
 
-def parse_listing_uri(uri: str, cache, client_config) -> tuple[str, str, str]:
+def parse_listing_uri(uri: str, cache, client_config) -> tuple[Optional[str], str, str]:
     """
     Parsing uri and returns listing dataset name, listing uri and listing path
     """
@@ -85,7 +94,9 @@ def parse_listing_uri(uri: str, cache, client_config) -> tuple[str, str, str]:
     storage_uri, path = Client.parse_url(uri)
     telemetry.log_param("client", client.PREFIX)
 
-    if uses_glob(path) or client.fs.isfile(uri):
+    if not uri.endswith("/") and client.fs.isfile(uri):
+        return None, f'{storage_uri}/{path.lstrip("/")}', path
+    if uses_glob(path):
         lst_uri_path = posixpath.dirname(path)
     else:
         storage_uri, path = Client.parse_url(f'{uri.rstrip("/")}/')
@@ -113,7 +124,7 @@ def listing_uri_from_name(dataset_name: str) -> str:
 
 def get_listing(
     uri: str, session: "Session", update: bool = False
-) -> tuple[str, str, str, bool]:
+) -> tuple[Optional[str], str, str, bool]:
     """Returns correct listing dataset name that must be used for saving listing
     operation. It takes into account existing listings and reusability of those.
     It also returns boolean saying if returned dataset name is reused / already
@@ -131,6 +142,10 @@ def get_listing(
     ds_name, list_uri, list_path = parse_listing_uri(uri, cache, client_config)
     listing = None
 
+    # if we don't want to use a cached dataset (e.g. for a single file listing)
+    if not ds_name:
+        return None, list_uri, list_path, False
+
     listings = [
         ls for ls in catalog.listings() if not ls.is_expired and ls.contains(ds_name)
     ]
diff --git a/src/datachain/listing.py b/src/datachain/listing.py
index b68fe493c..f2d1b0e36 100644
--- a/src/datachain/listing.py
+++ b/src/datachain/listing.py
@@ -157,11 +157,7 @@ def instantiate_nodes(
 
         counter = 0
         for node in all_nodes:
-            dst = os.path.join(output, *node.path)
-            dst_dir = os.path.dirname(dst)
-            os.makedirs(dst_dir, exist_ok=True)
-            file = node.n.to_file(self.client.uri)
-            self.client.instantiate_object(file, dst, progress_bar, force)
+            node.instantiate(self.client, output, progress_bar, force=force)
             counter += 1
             if counter > 1000:
                 progress_bar.update(counter)
diff --git a/src/datachain/node.py b/src/datachain/node.py
index 0f95e870a..60c437ed9 100644
--- a/src/datachain/node.py
+++ b/src/datachain/node.py
@@ -1,3 +1,4 @@
+import os
 from datetime import datetime
 from typing import TYPE_CHECKING, Any, Optional
 
@@ -10,6 +11,8 @@
 if TYPE_CHECKING:
     from typing_extensions import Self
 
+    from datachain.client import Client
+
 
 class DirType:
     FILE = 0
@@ -114,7 +117,21 @@ def to_file(self, source: Optional[StorageURI] = None) -> File:
         )
 
     @classmethod
-    def from_dict(cls, d: dict[str, Any], file_prefix: str = "file") -> "Self":
+    def from_file(cls, f: File) -> "Self":
+        return cls(
+            source=StorageURI(f.source),
+            path=f.path,
+            etag=f.etag,
+            is_latest=f.is_latest,
+            size=f.size,
+            last_modified=f.last_modified,
+            version=f.version,
+            location=str(f.location) if f.location else None,
+            dir_type=DirType.FILE,
+        )
+
+    @classmethod
+    def from_row(cls, d: dict[str, Any], file_prefix: str = "file") -> "Self":
         def _dval(field_name: str):
             return d.get(f"{file_prefix}__{field_name}")
@@ -174,6 +191,15 @@ def full_path(self) -> str:
             path += "/"
         return path
 
+    def instantiate(
+        self, client: "Client", output: str, progress_bar, *, force: bool = False
+    ):
+        dst = os.path.join(output, *self.path)
+        dst_dir = os.path.dirname(dst)
+        os.makedirs(dst_dir, exist_ok=True)
+        file = self.n.to_file(client.uri)
+        client.instantiate_object(file, dst, progress_bar, force)
+
 
 TIME_FMT = "%Y-%m-%d %H:%M"
diff --git a/tests/func/test_catalog.py b/tests/func/test_catalog.py
index fb353a911..0d2798b89 100644
--- a/tests/func/test_catalog.py
+++ b/tests/func/test_catalog.py
@@ -462,7 +462,7 @@ def test_cp_single_file(cloud_test_catalog, no_glob):
 
 
 @pytest.mark.parametrize("tree", [{"foo": "original"}], indirect=True)
-def test_storage_mutation(cloud_test_catalog):
+def test_cp_file_storage_mutation(cloud_test_catalog):
     working_dir = cloud_test_catalog.working_dir
     catalog = cloud_test_catalog.catalog
     src_path = f"{cloud_test_catalog.src_uri}/foo"
@@ -476,15 +476,15 @@ def test_storage_mutation(cloud_test_catalog):
     dest = working_dir / "data2"
     dest.mkdir()
     catalog.cp([src_path], str(dest / "local"), no_edatachain_file=True)
-    assert tree_from_path(dest) == {"local": "original"}
+    assert tree_from_path(dest) == {"local": "modified"}
 
-    # Since the old version cannot be found in storage or cache, it's an error.
+    # For a single file we access the storage directly instead of going through
+    # the listing, so the old etag is never checked against the modified file
     catalog.cache.clear()
     dest = working_dir / "data3"
     dest.mkdir()
-    with pytest.raises(FileNotFoundError):
-        catalog.cp([src_path], str(dest / "local"), no_edatachain_file=True)
-    assert tree_from_path(dest) == {}
+    catalog.cp([src_path], str(dest / "local"), no_edatachain_file=True)
+    assert tree_from_path(dest) == {"local": "modified"}
 
     catalog.index([src_path], update=True)
     dest = working_dir / "data4"
@@ -493,6 +493,43 @@ def test_storage_mutation(cloud_test_catalog):
     assert tree_from_path(dest) == {"local": "modified"}
 
 
+@pytest.mark.parametrize("tree", [{"foo-dir": "original"}], indirect=True)
+def test_cp_dir_storage_mutation(cloud_test_catalog):
+    working_dir = cloud_test_catalog.working_dir
+    catalog = cloud_test_catalog.catalog
+    src_path = f"{cloud_test_catalog.src_uri}/"
+
+    dest = working_dir / "data1"
+    dest.mkdir()
+    catalog.cp([src_path], str(dest / "local"), no_edatachain_file=True, recursive=True)
+    assert tree_from_path(dest) == {"local": {"foo-dir": "original"}}
+
+    (cloud_test_catalog.src / "foo-dir").write_text("modified")
+    dest = working_dir / "data2"
+    dest.mkdir()
+    catalog.cp([src_path], str(dest / "local"), no_edatachain_file=True, recursive=True)
+    assert tree_from_path(dest) == {"local": {"foo-dir": "original"}}
+
+    # For a dir we access files through the listing, so it finds an etag for the
+    # original file; that version is no longer in the cache and the file has been
+    # modified in the local storage, so the file referenced by the listing can
+    # no longer be found
+    catalog.cache.clear()
+    dest = working_dir / "data3"
+    dest.mkdir()
+    with pytest.raises(FileNotFoundError):
+        catalog.cp(
+            [src_path], str(dest / "local"), no_edatachain_file=True, recursive=True
+        )
+    assert tree_from_path(dest) == {"local": {}}
+
+    catalog.index([src_path], update=True)
+    dest = working_dir / "data4"
+    dest.mkdir()
+    catalog.cp([src_path], str(dest / "local"), no_edatachain_file=True, recursive=True)
+    assert tree_from_path(dest) == {"local": {"foo-dir": "modified"}}
+
+
 def test_cp_edatachain_file_options(cloud_test_catalog):
     working_dir = cloud_test_catalog.working_dir
     catalog = cloud_test_catalog.catalog
@@ -734,6 +771,42 @@ def test_ls_glob(cloud_test_catalog):
     ) == [("dog1", ["dog1"]), ("dog2", ["dog2"]), ("dog3", ["dog3"])]
 
 
+def test_ls_file(cloud_test_catalog):
+    src_uri = cloud_test_catalog.src_uri
+    catalog = cloud_test_catalog.catalog
+
+    assert sorted(
+        (source.node.name, [r[0] for r in results])
+        for source, results in catalog.ls([f"{src_uri}/dogs/dog1"], fields=["name"])
+    ) == [("dog1", ["dog1"])]
+
+
+def test_ls_dir_same_name_as_file(cloud_test_catalog, cloud_type):
+    src_uri = cloud_test_catalog.src_uri
+    catalog = cloud_test_catalog.catalog
+
+    path = f"{src_uri}/dogs/dog1"
+
+    # check that file exists
+    assert sorted(
+        (source.node.name, [r[0] for r in results])
+        for source, results in catalog.ls([path], fields=["name"])
+    ) == [("dog1", ["dog1"])]
+
+    if cloud_type == "file":
+        # should be fixed upstream in fsspec
+        # boils down to https://github.com/fsspec/filesystem_spec/pull/1567#issuecomment-2563160414
+        # fsspec removes the trailing slash and returns a file, that's why we
+        # are not getting an error here
+        assert sorted(
+            (source.node.name, [r[0] for r in results])
+            for source, results in catalog.ls([f"{path}/"], fields=["name"])
+        ) == [("", ["."])]
+    else:
+        with pytest.raises(FileNotFoundError):
+            next(catalog.ls([f"{path}/"], fields=["name"]))
+
+
 def test_ls_prefix_not_found(cloud_test_catalog):
     src_uri = cloud_test_catalog.src_uri
     catalog = cloud_test_catalog.catalog
@@ -891,9 +964,8 @@ def test_enlist_source_handles_file(cloud_test_catalog):
     src_path = f"{src_uri}/dogs/dog1"
     catalog.enlist_source(src_path)
 
-    stats = listing_stats(src_path, catalog)
-    assert stats.num_objects == len(DEFAULT_TREE["dogs"])
-    assert stats.size == 15
+    with pytest.raises(DatasetNotFoundError):
+        listing_stats(src_path, catalog)
 
 
 @pytest.mark.parametrize("from_cli", [False, True])
diff --git a/tests/func/test_datachain.py b/tests/func/test_datachain.py
index 79824598f..f4f504a4d 100644
--- a/tests/func/test_datachain.py
+++ b/tests/func/test_datachain.py
@@ -136,7 +136,9 @@ def test_from_storage_partials(cloud_test_catalog):
     catalog = session.catalog
 
     def _list_dataset_name(uri: str) -> str:
-        return parse_listing_uri(uri, catalog.cache, catalog.client_config)[0]
+        name = parse_listing_uri(uri, catalog.cache, catalog.client_config)[0]
+        assert name
+        return name
 
     dogs_uri = f"{src_uri}/dogs"
     DataChain.from_storage(dogs_uri, session=session)
@@ -178,7 +180,9 @@ def test_from_storage_partials_with_update(cloud_test_catalog):
     catalog = session.catalog
 
     def _list_dataset_name(uri: str) -> str:
-        return parse_listing_uri(uri, catalog.cache, catalog.client_config)[0]
+        name = parse_listing_uri(uri, catalog.cache, catalog.client_config)[0]
+        assert name
+        return name
 
     uri = f"{src_uri}/cats"
     DataChain.from_storage(uri, session=session)
diff --git a/tests/unit/test_listing.py b/tests/unit/test_listing.py
index b7407904b..3e5800d75 100644
--- a/tests/unit/test_listing.py
+++ b/tests/unit/test_listing.py
@@ -161,7 +161,7 @@ def test_list_dir(listing):
 
 def test_list_file(listing):
     file = listing.resolve_path("dir1/dataset.csv")
-    src = DataSource(listing, file)
+    src = DataSource(listing, listing.client, file)
     results = list(src.ls(["sys__id", "name", "dir_type"]))
     assert {r[1] for r in results} == {"dataset.csv"}
     assert results[0][0] == file.sys__id
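
Usage sketch (not part of the patch): a minimal illustration of the behavior this diff enables, assuming only DataChain.from_storage as shown above; the bucket and paths are hypothetical.

    from datachain.lib.dc import DataChain

    # Single object: parse_listing_uri()/get_listing() return no listing dataset
    # name, so from_storage() skips the listing and builds the chain from one
    # File obtained via Client.get_file_info().
    single_file = DataChain.from_storage("s3://my-bucket/reports/2024.csv")

    # Directory (trailing slash): a listing dataset is still created or reused,
    # exactly as before this change.
    whole_dir = DataChain.from_storage("s3://my-bucket/reports/")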