diff --git a/src/datachain/lib/dc.py b/src/datachain/lib/dc.py
index 138e2a131..3836a37ed 100644
--- a/src/datachain/lib/dc.py
+++ b/src/datachain/lib/dc.py
@@ -1,6 +1,8 @@
 import copy
 import os
+import os.path
 import re
+import sys
 from collections.abc import Iterator, Sequence
 from functools import wraps
 from typing import (
@@ -1887,21 +1889,48 @@ def to_parquet(
         path: Union[str, os.PathLike[str], BinaryIO],
         partition_cols: Optional[Sequence[str]] = None,
         chunk_size: int = DEFAULT_PARQUET_CHUNK_SIZE,
+        fs_kwargs: Optional[dict[str, Any]] = None,
         **kwargs,
     ) -> None:
         """Save chain to parquet file with SignalSchema metadata.
 
         Parameters:
-            path : Path or a file-like binary object to save the file.
+            path : Path or a file-like binary object to save the file. This supports
+                local paths as well as remote paths, such as s3:// or hf:// with fsspec.
             partition_cols : Column names by which to partition the dataset.
             chunk_size : The chunk size of results to read and convert to columnar data,
                 to avoid running out of memory.
+            fs_kwargs : Optional kwargs to pass to the fsspec filesystem, used only for
+                write, for fsspec-type URLs, such as s3:// or hf:// when
+                provided as the destination path.
         """
         import pyarrow as pa
         import pyarrow.parquet as pq
 
         from datachain.lib.arrow import DATACHAIN_SIGNAL_SCHEMA_PARQUET_KEY
 
+        fsspec_fs = None
+
+        if isinstance(path, str) and "://" in path:
+            from datachain.client.fsspec import Client
+
+            fs_kwargs = {
+                **self._query.catalog.client_config,
+                **(fs_kwargs or {}),
+            }
+
+            client = Client.get_implementation(path)
+
+            if path.startswith("file://"):
+                # pyarrow does not handle file:// uris, and needs a direct path instead.
+                from urllib.parse import urlparse
+
+                path = urlparse(path).path
+                if sys.platform == "win32":
+                    path = os.path.normpath(path.lstrip("/"))
+
+            fsspec_fs = client.create_fs(**fs_kwargs)
+
         _partition_cols = list(partition_cols) if partition_cols else None
         signal_schema_metadata = orjson.dumps(
             self._effective_signals_schema.serialize()
@@ -1936,12 +1965,15 @@ def to_parquet(
                     table,
                     root_path=path,
                     partition_cols=_partition_cols,
+                    filesystem=fsspec_fs,
                     **kwargs,
                 )
             else:
                 if first_chunk:
                     # Write to a single parquet file.
-                    parquet_writer = pq.ParquetWriter(path, parquet_schema, **kwargs)
+                    parquet_writer = pq.ParquetWriter(
+                        path, parquet_schema, filesystem=fsspec_fs, **kwargs
+                    )
                     first_chunk = False
 
                 assert parquet_writer
@@ -1954,22 +1986,43 @@ def to_csv(
         self,
         path: Union[str, os.PathLike[str]],
         delimiter: str = ",",
+        fs_kwargs: Optional[dict[str, Any]] = None,
         **kwargs,
     ) -> None:
         """Save chain to a csv (comma-separated values) file.
 
         Parameters:
-            path : Path to save the file.
+            path : Path to save the file. This supports local paths as well as
+                remote paths, such as s3:// or hf:// with fsspec.
             delimiter : Delimiter to use for the resulting file.
+            fs_kwargs : Optional kwargs to pass to the fsspec filesystem, used only for
+                write, for fsspec-type URLs, such as s3:// or hf:// when
+                provided as the destination path.
         """
         import csv
 
+        opener = open
+
+        if isinstance(path, str) and "://" in path:
+            from datachain.client.fsspec import Client
+
+            fs_kwargs = {
+                **self._query.catalog.client_config,
+                **(fs_kwargs or {}),
+            }
+
+            client = Client.get_implementation(path)
+
+            fsspec_fs = client.create_fs(**fs_kwargs)
+
+            opener = fsspec_fs.open
+
         headers, _ = self._effective_signals_schema.get_headers_with_length()
         column_names = [".".join(filter(None, header)) for header in headers]
 
         results_iter = self.collect_flatten()
 
-        with open(path, "w", newline="") as f:
+        with opener(path, "w", newline="") as f:
             writer = csv.writer(f, delimiter=delimiter, **kwargs)
             writer.writerow(column_names)
 
diff --git a/tests/conftest.py b/tests/conftest.py
index 7a5ea0978..0f392954c 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -558,6 +558,26 @@ def cloud_test_catalog(
     )
 
 
+@pytest.fixture
+def cloud_test_catalog_upload(cloud_test_catalog):
+    """This returns a version of the cloud_test_catalog that is suitable for uploading
+    files, and will perform the necessary cleanup of any uploaded files."""
+    from datachain.client.fsspec import Client
+
+    src = cloud_test_catalog.src_uri
+    client = Client.get_implementation(src)
+    fsspec_fs = client.create_fs(**cloud_test_catalog.client_config)
+    original_paths = set(fsspec_fs.ls(src))
+
+    yield cloud_test_catalog
+
+    # Cleanup any written files
+    new_paths = set(fsspec_fs.ls(src))
+    cleanup_paths = new_paths - original_paths
+    for p in cleanup_paths:
+        fsspec_fs.rm(p, recursive=True)
+
+
 @pytest.fixture
 def cloud_test_catalog_tmpfile(
     cloud_server,
diff --git a/tests/func/test_datachain.py b/tests/func/test_datachain.py
index b256ba296..1b4ff705a 100644
--- a/tests/func/test_datachain.py
+++ b/tests/func/test_datachain.py
@@ -38,6 +38,12 @@
     text_embedding,
 )
 
+DF_DATA = {
+    "first_name": ["Alice", "Bob", "Charlie", "David", "Eva"],
+    "age": [25, 30, 35, 40, 45],
+    "city": ["New York", "Los Angeles", "Chicago", "Houston", "Phoenix"],
+}
+
 
 def _get_listing_datasets(session):
     return sorted(
@@ -1366,3 +1372,47 @@ def file_info(file: File) -> DataModel:
         ],
         "file_info__path",
     )
+
+
+def test_to_from_csv_remote(cloud_test_catalog_upload):
+    ctc = cloud_test_catalog_upload
+    path = f"{ctc.src_uri}/test.csv"
+
+    df = pd.DataFrame(DF_DATA)
+    dc_to = DataChain.from_pandas(df, session=ctc.session)
+    dc_to.to_csv(path)
+
+    dc_from = DataChain.from_csv(path, session=ctc.session)
+    df1 = dc_from.select("first_name", "age", "city").to_pandas()
+    assert df1.equals(df)
+
+
+@pytest.mark.parametrize("chunk_size", (1000, 2))
+@pytest.mark.parametrize("kwargs", ({}, {"compression": "gzip"}))
+def test_to_from_parquet_remote(cloud_test_catalog_upload, chunk_size, kwargs):
+    ctc = cloud_test_catalog_upload
+    path = f"{ctc.src_uri}/test.parquet"
+
+    df = pd.DataFrame(DF_DATA)
+    dc_to = DataChain.from_pandas(df, session=ctc.session)
+    dc_to.to_parquet(path, chunk_size=chunk_size, **kwargs)
+
+    dc_from = DataChain.from_parquet(path, session=ctc.session)
+    df1 = dc_from.select("first_name", "age", "city").to_pandas()
+
+    assert df1.equals(df)
+
+
+@pytest.mark.parametrize("chunk_size", (1000, 2))
+def test_to_from_parquet_partitioned_remote(cloud_test_catalog_upload, chunk_size):
+    ctc = cloud_test_catalog_upload
+    path = f"{ctc.src_uri}/parquets"
+
+    df = pd.DataFrame(DF_DATA)
+    dc_to = DataChain.from_pandas(df, session=ctc.session)
+    dc_to.to_parquet(path, partition_cols=["first_name"], chunk_size=chunk_size)
+
+    dc_from = DataChain.from_parquet(path, session=ctc.session)
+    df1 = dc_from.select("first_name", "age", "city").to_pandas()
+    df1 = df1.sort_values("first_name").reset_index(drop=True)
+    assert df1.equals(df)
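
Usage sketch (not part of the patch): the docstrings added above describe writing to fsspec-style URLs with the new fs_kwargs argument; the snippet below shows roughly how that would be called. The bucket name and the anon option are illustrative assumptions, not values taken from this change.

    import pandas as pd

    from datachain.lib.dc import DataChain

    df = pd.DataFrame(
        {
            "first_name": ["Alice", "Bob"],
            "age": [25, 30],
            "city": ["New York", "Chicago"],
        }
    )
    chain = DataChain.from_pandas(df)

    # Single parquet file on S3; fs_kwargs is merged over the catalog's client
    # config and passed to the fsspec filesystem used for the write.
    chain.to_parquet("s3://my-bucket/people.parquet", fs_kwargs={"anon": False})

    # Partitioned parquet dataset and a CSV file written the same way.
    chain.to_parquet("s3://my-bucket/people/", partition_cols=["first_name"])
    chain.to_csv("s3://my-bucket/people.csv", fs_kwargs={"anon": False})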