Add a `link_type` option to file export.

`export_files()` in src/datachain/lib/dc.py and `export()` in
src/datachain/lib/file.py accept `link_type: Literal["copy", "symlink"]`
(default "copy"). With "symlink", `export()` calls the new `_symlink_to()`;
if that raises OSError with errno ENOTSUP, EXDEV or ENOSYS, the export falls
back to `self.save(dst)`. Any other OSError is re-raised.

Note: the `index` post-image hashes for file.py and test_file.py were recorded
before the docstrings/comments below were added to this patch.

diff --git a/src/datachain/lib/dc.py b/src/datachain/lib/dc.py
index 4f84255a4..a7e749a7c 100644
--- a/src/datachain/lib/dc.py
+++ b/src/datachain/lib/dc.py
@@ -11,6 +11,7 @@
     BinaryIO,
     Callable,
     ClassVar,
+    Literal,
     Optional,
     TypeVar,
     Union,
@@ -2418,6 +2419,7 @@ def export_files(
         signal="file",
         placement: FileExportPlacement = "fullpath",
         use_cache: bool = True,
+        link_type: Literal["copy", "symlink"] = "copy",
     ) -> None:
         """Method that exports all files from chain to some folder."""
         if placement == "filename" and (
@@ -2427,7 +2429,7 @@ def export_files(
                 raise ValueError("Files with the same name found")
 
         for file in self.collect(signal):
-            file.export(output, placement, use_cache)  # type: ignore[union-attr]
+            file.export(output, placement, use_cache, link_type=link_type)  # type: ignore[union-attr]
 
     def shuffle(self) -> "Self":
         """Shuffle the rows of the chain deterministically."""
diff --git a/src/datachain/lib/file.py b/src/datachain/lib/file.py
index 1aaa3fc91..fbba3ab14 100644
--- a/src/datachain/lib/file.py
+++ b/src/datachain/lib/file.py
@@ -1,3 +1,4 @@
+import errno
 import hashlib
 import io
 import json
@@ -236,11 +237,36 @@ def save(self, destination: str):
         with open(destination, mode="wb") as f:
             f.write(self.read())
 
+    def _symlink_to(self, destination: str) -> None:
+        """Symlink `destination` to a local copy of this file.
+
+        The link target is the cached copy when caching is enabled, or the
+        path returned by `get_path()` when the source is a `file://` location.
+
+        Raises:
+            OSError: with `errno.ENOTSUP` when `self.location` is set (virtual
+                file), or `errno.EXDEV` when the file is neither cached nor
+                on a `file://` source. `os.symlink` errors propagate as-is.
+        """
+        if self.location:
+            raise OSError(errno.ENOTSUP, "Symlinking virtual file is not supported")
+
+        if self._caching_enabled:
+            self.ensure_cached()
+            source = self.get_local_path()
+            assert source, "File was not cached"
+        elif self.source.startswith("file://"):
+            source = self.get_path()
+        else:
+            raise OSError(errno.EXDEV, "can't link across filesystems")
+        return os.symlink(source, destination)
+
     def export(
         self,
         output: str,
         placement: ExportPlacement = "fullpath",
         use_cache: bool = True,
+        link_type: Literal["copy", "symlink"] = "copy",
     ) -> None:
         """Export file to new location."""
         if use_cache:
@@ -249,6 +275,14 @@ def export(
         dst_dir = os.path.dirname(dst)
         os.makedirs(dst_dir, exist_ok=True)
 
+        if link_type == "symlink":
+            try:
+                return self._symlink_to(dst)
+            except OSError as exc:
+                # Unsupported/cross-filesystem links fall back to a plain copy.
+                if exc.errno not in (errno.ENOTSUP, errno.EXDEV, errno.ENOSYS):
+                    raise
+
         self.save(dst)
 
     def _set_stream(
diff --git a/tests/unit/lib/test_file.py b/tests/unit/lib/test_file.py
index 87e574fd8..eaddb4f61 100644
--- a/tests/unit/lib/test_file.py
+++ b/tests/unit/lib/test_file.py
@@ -1,4 +1,5 @@
 import json
+from pathlib import Path
 from unittest.mock import Mock
 
 import pytest
@@ -379,3 +380,18 @@ def test_get_local_path(tmp_path, catalog):
     assert file.get_local_path() is None
     file.ensure_cached()
     assert file.get_local_path() is not None
+
+
+@pytest.mark.parametrize("use_cache", (True, False))
+def test_export_with_symlink(tmp_path, catalog, use_cache):
+    """Exporting with link_type="symlink" links to the cached or source file."""
+    (tmp_path / "myfile.txt").write_text("some text")
+
+    file = File(path="myfile.txt", source=f"file://{tmp_path}")
+    file._set_stream(catalog, use_cache)
+
+    file.export(tmp_path / "dir", link_type="symlink", use_cache=use_cache)
+    assert (tmp_path / "dir" / "myfile.txt").is_symlink()
+
+    expected_dest = file.get_local_path() if use_cache else file.get_path()
+    assert (tmp_path / "dir" / "myfile.txt").readlink() == Path(expected_dest)