Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add resolve files #313

Merged
merged 4 commits into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions src/datachain/lib/file.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import io
import json
import logging
import os
import posixpath
from abc import ABC, abstractmethod
Expand All @@ -15,6 +16,9 @@
from PIL import Image
from pydantic import Field, field_validator

if TYPE_CHECKING:
from typing_extensions import Self

from datachain.cache import UniqueId
from datachain.client.fileslice import FileSlice
from datachain.lib.data_model import DataModel
Expand All @@ -25,6 +29,8 @@
if TYPE_CHECKING:
from datachain.catalog import Catalog

logger = logging.getLogger("datachain")

# how to create file path when exporting
ExportPlacement = Literal["filename", "etag", "fullpath", "checksum"]

Expand Down Expand Up @@ -313,6 +319,70 @@ def get_fs(self):
"""Returns `fsspec` filesystem for the file."""
return self._catalog.get_client(self.source).fs

def resolve(self) -> "Self":
"""
Resolve a File object by checking its existence and updating its metadata.

Returns:
File: The resolved File object with updated metadata.
"""
if self._catalog is None:
raise RuntimeError("Cannot resolve file: catalog is not set")

try:
client = self._catalog.get_client(self.source)
except NotImplementedError as e:
raise RuntimeError(
f"Unsupported protocol for file source: {self.source}"
) from e

try:
info = client.fs.info(client.get_full_path(self.path))
converted_info = client.info_to_file(info, self.source)
return type(self)(
path=self.path,
source=self.source,
size=converted_info.size,
etag=converted_info.etag,
version=converted_info.version,
is_latest=converted_info.is_latest,
last_modified=converted_info.last_modified,
location=self.location,
)
except (FileNotFoundError, PermissionError, OSError) as e:
logger.warning("File system error when resolving %s: %s", self.path, str(e))

return type(self)(
path=self.path,
source=self.source,
size=0,
etag="",
version="",
is_latest=True,
last_modified=TIME_ZERO,
location=self.location,
)


def resolve(file: File) -> File:
"""
Resolve a File object by checking its existence and updating its metadata.

This function is a wrapper around the File.resolve() method, designed to be
used as a mapper in DataChain operations.

Args:
file (File): The File object to resolve.

Returns:
File: The resolved File object with updated metadata.

Raises:
RuntimeError: If the file's catalog is not set or if
the file source protocol is unsupported.
"""
return file.resolve()


class TextFile(File):
"""`DataModel` for reading text files."""
Expand Down
58 changes: 57 additions & 1 deletion tests/unit/lib/test_file.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import json
from unittest.mock import Mock

import pytest
from fsspec.implementations.local import LocalFileSystem
from PIL import Image

from datachain import DataChain
from datachain.cache import UniqueId
from datachain.catalog import Catalog
from datachain.lib.file import File, ImageFile, TextFile
from datachain.lib.file import File, ImageFile, TextFile, resolve
from datachain.utils import TIME_ZERO


def create_file(source: str):
Expand Down Expand Up @@ -319,3 +322,56 @@ def test_read_text(tmp_path, catalog):
file = File(path=file_name, source=f"file://{tmp_path}")
file._set_stream(catalog, True)
assert file.read_text() == data


def test_resolve_file(cloud_test_catalog):
ctc = cloud_test_catalog

dc = DataChain.from_storage(ctc.src_uri, session=ctc.session)
for orig_file in dc.collect("file"):
resolved_file = File(source=orig_file.source, path=orig_file.path)
resolved_file._catalog = ctc.catalog
assert orig_file == resolved_file.resolve()


def test_resolve_file_no_exist(cloud_test_catalog):
ctc = cloud_test_catalog

non_existent_file = File(source=ctc.src_uri, path="non_existent_file.txt")
non_existent_file._catalog = ctc.catalog
resolved_non_existent = non_existent_file.resolve()
assert resolved_non_existent.size == 0
assert resolved_non_existent.etag == ""
assert resolved_non_existent.last_modified == TIME_ZERO


def test_resolve_unsupported_protocol():
mock_catalog = Mock()
mock_catalog.get_client.side_effect = NotImplementedError("Unsupported protocol")

file = File(source="unsupported://example.com", path="test.txt")
file._catalog = mock_catalog

with pytest.raises(RuntimeError) as exc_info:
file.resolve()

assert (
str(exc_info.value)
== "Unsupported protocol for file source: unsupported://example.com"
)


def test_file_resolve_no_catalog():
file = File(path="test.txt", source="s3://mybucket")
with pytest.raises(RuntimeError, match="Cannot resolve file: catalog is not set"):
file.resolve()


def test_resolve_function():
mock_file = Mock(spec=File)
mock_file.resolve.return_value = "resolved_file"

result = resolve(mock_file)

assert result == "resolved_file"
mock_file.resolve.assert_called_once()
Loading