Skip to content

Commit

Permalink
add resolve
Browse files Browse the repository at this point in the history
  • Loading branch information
EdwardLi-coder committed Sep 7, 2024
1 parent a0ce094 commit 50dc285
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 2 deletions.
71 changes: 70 additions & 1 deletion src/datachain/lib/file.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import io
import json
import logging
import os
import posixpath
from abc import ABC, abstractmethod
from contextlib import contextmanager
from datetime import datetime
from datetime import datetime, timezone
from io import BytesIO
from pathlib import Path, PurePosixPath
from typing import TYPE_CHECKING, Any, ClassVar, Literal, Optional, Union
Expand All @@ -25,6 +26,8 @@
if TYPE_CHECKING:
from datachain.catalog import Catalog

logger = logging.getLogger("datachain")

# how to create file path when exporting
ExportPlacement = Literal["filename", "etag", "fullpath", "checksum"]

Expand Down Expand Up @@ -315,6 +318,72 @@ def get_fs(self):
"""Returns `fsspec` filesystem for the file."""
return self._catalog.get_client(self.source).fs

def resolve(self) -> "File":
    """
    Resolve a File object by checking its existence and updating its metadata.

    The file is looked up through the filesystem client that the catalog
    provides for `self.source`, and a new object of the same class is built
    from the returned metadata. If the lookup fails with an `OSError`
    (this includes `FileNotFoundError` and `PermissionError`), a warning is
    logged and a placeholder with zero size, empty etag/version and
    `TIME_ZERO` as `last_modified` is returned instead of raising.

    Returns:
        File: A new object of the same class, with `path`, `source` and
            `location` copied from this one and the remaining metadata
            refreshed (or zeroed when the lookup failed).

    Raises:
        RuntimeError: If the catalog is not set, or if the catalog cannot
            provide a client for the protocol of `self.source`.
    """
    if self._catalog is None:
        raise RuntimeError("Cannot resolve file: catalog is not set")

    try:
        client = self._catalog.get_client(self.source)
    except NotImplementedError as e:
        raise RuntimeError(
            f"Unsupported protocol for file source: {self.source}"
        ) from e

    # Only the filesystem lookup and the metadata conversion are guarded;
    # building the result object happens outside the `try` so unrelated
    # errors are not mistaken for a missing file.
    try:
        info = client.fs.info(client.get_full_path(self.path))
        converted_info = client.convert_info(info, self.source)
    except OSError as e:
        # OSError already covers FileNotFoundError and PermissionError.
        logger.warning("File system error when resolving %s: %s", self.path, e)
        return type(self)(
            path=self.path,
            source=self.source,
            size=0,
            etag="",
            version="",
            is_latest=True,
            last_modified=TIME_ZERO,
            location=self.location,
        )

    return type(self)(
        path=self.path,
        source=self.source,
        size=getattr(converted_info, "size", 0),
        etag=getattr(converted_info, "etag", ""),
        # A missing or None version is normalized to an empty string.
        version=getattr(converted_info, "version", None) or "",
        is_latest=getattr(converted_info, "is_latest", True),
        last_modified=getattr(
            converted_info, "last_modified", datetime.now(timezone.utc)
        ),
        location=self.location,
    )


def resolve(file: File) -> File:
    """
    Resolve a File object by checking its existence and updating its metadata.

    Module-level counterpart of the `File.resolve()` method: it simply
    delegates to that method, which lets it be passed directly as a mapper
    in DataChain operations.

    Args:
        file (File): The File object to resolve.

    Returns:
        File: The resolved File object with updated metadata.

    Raises:
        RuntimeError: If the file's catalog is not set or if
            the file source protocol is unsupported.
    """
    resolved = file.resolve()
    return resolved


class TextFile(File):
"""`DataModel` for reading text files."""
Expand Down
58 changes: 57 additions & 1 deletion tests/unit/lib/test_file.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import json
from unittest.mock import Mock

import pytest
from fsspec.implementations.local import LocalFileSystem
from PIL import Image

from datachain import DataChain
from datachain.cache import UniqueId
from datachain.catalog import Catalog
from datachain.lib.file import File, ImageFile, TextFile
from datachain.lib.file import File, ImageFile, TextFile, resolve
from datachain.utils import TIME_ZERO


def create_file(source: str):
Expand Down Expand Up @@ -323,3 +326,56 @@ def test_read_text(tmp_path, catalog):
file = File(path=file_name, source=f"file://{tmp_path}")
file._set_stream(catalog, True)
assert file.read_text() == data


def test_resolve_file(cloud_test_catalog):
    """Resolving a bare `File` reproduces the metadata listed from storage."""
    env = cloud_test_catalog
    chain = DataChain.from_storage(env.src_uri, session=env.session)

    for expected in chain.collect("file"):
        # Start from path/source only, so every other field must come
        # from `resolve()`.
        bare = File(source=expected.source, path=expected.path)
        bare._catalog = env.catalog
        assert expected == bare.resolve()


def test_resolve_file_no_exist(cloud_test_catalog):
    """Resolving a missing file yields zeroed placeholder metadata."""
    env = cloud_test_catalog

    missing = File(source=env.src_uri, path="non_existent_file.txt")
    missing._catalog = env.catalog

    placeholder = missing.resolve()

    assert placeholder.size == 0
    assert placeholder.etag == ""
    assert placeholder.last_modified == TIME_ZERO


def test_resolve_unsupported_protocol():
    """A `NotImplementedError` from the catalog surfaces as `RuntimeError`."""
    catalog = Mock()
    catalog.get_client.side_effect = NotImplementedError("Unsupported protocol")

    target = File(source="unsupported://example.com", path="test.txt")
    target._catalog = catalog

    expected_message = (
        "Unsupported protocol for file source: unsupported://example.com"
    )
    with pytest.raises(RuntimeError) as exc_info:
        target.resolve()

    # Exact comparison: the message must name the offending source verbatim.
    assert str(exc_info.value) == expected_message


def test_file_resolve_no_catalog():
    """`File.resolve()` refuses to run when no catalog is attached."""
    detached = File(path="test.txt", source="s3://mybucket")
    expected_error = "Cannot resolve file: catalog is not set"

    with pytest.raises(RuntimeError, match=expected_error):
        detached.resolve()


def test_resolve_function():
    """The module-level `resolve()` delegates to `File.resolve()` exactly once."""
    stub = Mock(spec=File)
    stub.resolve.return_value = "resolved_file"

    outcome = resolve(stub)

    stub.resolve.assert_called_once()
    assert outcome == "resolved_file"

0 comments on commit 50dc285

Please sign in to comment.