Skip to content

Commit

Permalink
file: support exporting files as a symlink
Browse files Browse the repository at this point in the history
Closes #807.
  • Loading branch information
skshetry committed Jan 15, 2025
1 parent 57899d2 commit 6e6aa3c
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/datachain/lib/dc.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
BinaryIO,
Callable,
ClassVar,
Literal,
Optional,
TypeVar,
Union,
Expand Down Expand Up @@ -2418,6 +2419,7 @@ def export_files(
signal="file",
placement: FileExportPlacement = "fullpath",
use_cache: bool = True,
link_type: Literal["copy", "symlink"] = "copy",
) -> None:
"""Method that exports all files from chain to some folder."""
if placement == "filename" and (
Expand All @@ -2427,7 +2429,7 @@ def export_files(
raise ValueError("Files with the same name found")

for file in self.collect(signal):
file.export(output, placement, use_cache) # type: ignore[union-attr]
file.export(output, placement, use_cache, link_type=link_type) # type: ignore[union-attr]

def shuffle(self) -> "Self":
"""Shuffle the rows of the chain deterministically."""
Expand Down
23 changes: 23 additions & 0 deletions src/datachain/lib/file.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import errno
import hashlib
import io
import json
Expand Down Expand Up @@ -236,11 +237,26 @@ def save(self, destination: str):
with open(destination, mode="wb") as f:
f.write(self.read())

def _symlink_to(self, destination: str):
if self.location:
raise OSError(errno.ENOTSUP, "Symlinking virtual file is not supported")

Check warning on line 242 in src/datachain/lib/file.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/lib/file.py#L242

Added line #L242 was not covered by tests

if self._caching_enabled:
self.ensure_cached()
source = self.get_local_path()
assert source, "File was not cached"
elif self.source.startswith("file://"):
source = self.get_path()
else:
raise OSError(errno.EXDEV, "can't link across filesystems")

Check warning on line 251 in src/datachain/lib/file.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/lib/file.py#L251

Added line #L251 was not covered by tests
return os.symlink(source, destination)

def export(
self,
output: str,
placement: ExportPlacement = "fullpath",
use_cache: bool = True,
link_type: Literal["copy", "symlink"] = "copy",
) -> None:
"""Export file to new location."""
if use_cache:
Expand All @@ -249,6 +265,13 @@ def export(
dst_dir = os.path.dirname(dst)
os.makedirs(dst_dir, exist_ok=True)

if link_type == "symlink":
try:
return self._symlink_to(dst)
except OSError as exc:

Check warning on line 271 in src/datachain/lib/file.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/lib/file.py#L271

Added line #L271 was not covered by tests
if exc.errno not in (errno.ENOTSUP, errno.EXDEV, errno.ENOSYS):
raise

Check warning on line 273 in src/datachain/lib/file.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/lib/file.py#L273

Added line #L273 was not covered by tests

self.save(dst)

def _set_stream(
Expand Down
15 changes: 15 additions & 0 deletions tests/unit/lib/test_file.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
from pathlib import Path
from unittest.mock import Mock

import pytest
Expand Down Expand Up @@ -379,3 +380,17 @@ def test_get_local_path(tmp_path, catalog):
assert file.get_local_path() is None
file.ensure_cached()
assert file.get_local_path() is not None


@pytest.mark.parametrize("use_cache", (True, False))
def test_export_with_symlink(tmp_path, catalog, use_cache):
(tmp_path / "myfile.txt").write_text("some text")

file = File(path="myfile.txt", source=f"file://{tmp_path}")
file._set_stream(catalog, use_cache)

file.export(tmp_path / "dir", link_type="symlink", use_cache=use_cache)
assert (tmp_path / "dir" / "myfile.txt").is_symlink()

expected_dest = file.get_local_path() if use_cache else file.get_path()
assert (tmp_path / "dir" / "myfile.txt").readlink() == Path(expected_dest)

0 comments on commit 6e6aa3c

Please sign in to comment.