Refactor DataChain.from_storage() to use new listing generator (#294)
* added generator to list a bucket

* added test for async to sync generator and fixing listing

* fix tests

* fix comments

* first version of from_storage without deprecated listing

* first version of from_storage without deprecated listing

* fixing tests and removing prints, refactoring

* fix listing generator output type

* fix linter

* fix docs

* fixing test

* fix list bucket args

* refactoring listing static methods

* fixing non-recursive queries

* refactoring

* fixing listing generator tests

* added partial test

* using ctc in test session

* moved listing functions to a separate file

* added listing unit tests

* fixing json

* fixing examples

* fix file signal type from storage

* fixing example

* refactoring ls function

* added more tests and fixed comments

* fixing test

* fix test name

* fixing windows tests

* returning to all tests

* removed constants from dc.py

* added ticket number

* couple of fixes from PR review

* added new method is_dataset_listing and assertions

* refactoring listing code

* added session on cloud test catalog and refactoring tests

* added uses glob util

* extracted partial with update to separate test

* returning Column from db_signals method

* import directly from datachain

* renamed boolean functions to use the is_ prefix

* removed kwargs from from_storage

* removed kwargs from datasets method

* refactoring parsing listing dataset name

* Update src/datachain/lib/file.py

Co-authored-by: Ronan Lamy <[email protected]>

* removed client config

* removed kwargs from from_records

* fixing comment

* fixing new test

* fixing listing

---------

Co-authored-by: Ronan Lamy <[email protected]>
Co-authored-by: ivan <[email protected]>
3 people authored Sep 2, 2024
1 parent 12ddf7b commit bd52a96
Showing 22 changed files with 501 additions and 113 deletions.
2 changes: 1 addition & 1 deletion examples/get_started/udfs/parallel.py
@@ -31,7 +31,7 @@ def path_len_benchmark(path):

# Run in chain
DataChain.from_storage(
path="gs://datachain-demo/dogs-and-cats/",
"gs://datachain-demo/dogs-and-cats/",
).settings(parallel=-1).map(
path_len_benchmark,
params=["file.path"],
2 changes: 1 addition & 1 deletion examples/get_started/udfs/simple.py
@@ -11,7 +11,7 @@ def path_len(path):
if __name__ == "__main__":
# Run in chain
DataChain.from_storage(
path="gs://datachain-demo/dogs-and-cats/",
uri="gs://datachain-demo/dogs-and-cats/",
).map(
path_len,
params=["file.path"],
7 changes: 4 additions & 3 deletions src/datachain/catalog/catalog.py
@@ -1422,17 +1422,18 @@ def get_dataset_dependencies(

return direct_dependencies

- def ls_datasets(self) -> Iterator[DatasetRecord]:
+ def ls_datasets(self, include_listing: bool = False) -> Iterator[DatasetRecord]:
datasets = self.metastore.list_datasets()
for d in datasets:
- if not d.is_bucket_listing:
+ if not d.is_bucket_listing or include_listing:
yield d

def list_datasets_versions(
self,
+ include_listing: bool = False,
) -> Iterator[tuple[DatasetRecord, "DatasetVersion", Optional["Job"]]]:
"""Iterate over all dataset versions with related jobs."""
- datasets = list(self.ls_datasets())
+ datasets = list(self.ls_datasets(include_listing=include_listing))

# preselect dataset versions jobs from db to avoid multiple queries
jobs_ids: set[str] = {
7 changes: 6 additions & 1 deletion src/datachain/dataset.py
@@ -25,6 +25,7 @@

DATASET_PREFIX = "ds://"
QUERY_DATASET_PREFIX = "ds_query_"
+ LISTING_PREFIX = "lst__"


def parse_dataset_uri(uri: str) -> tuple[str, Optional[int]]:
@@ -443,7 +444,11 @@ def is_bucket_listing(self) -> bool:
For bucket listing we implicitly create underlying dataset to hold data. This
method is checking if this is one of those datasets.
"""
- return Client.is_data_source_uri(self.name)
+ # TODO refactor and maybe remove method in
+ # https://github.com/iterative/datachain/issues/318
+ return Client.is_data_source_uri(self.name) or self.name.startswith(
+ LISTING_PREFIX
+ )

@property
def versions_values(self) -> list[int]:
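The docstring in the hunk above explains that bucket listings implicitly create an underlying dataset to hold the indexed data. A rough illustration of the two name shapes the updated check now accepts; the exact `lst__...` naming is produced by `parse_listing_uri()` and is an assumption here:

```python
from datachain.dataset import LISTING_PREFIX

# Hypothetical dataset names, for illustration only.
names = [
    "gs://datachain-demo",                                  # old style: named after the source URI
    LISTING_PREFIX + "gs://datachain-demo/dogs-and-cats/",  # new style: "lst__" + listing URI
    "my_cats",                                              # regular user dataset
]

for name in names:
    # Mirrors the updated is_bucket_listing() check: a data-source URI name
    # or a name carrying the listing prefix counts as an implicit listing dataset.
    is_listing = name.startswith(("s3://", "gs://", "az://", "file://", LISTING_PREFIX))
    print(f"{name!r}: {is_listing}")
```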
102 changes: 77 additions & 25 deletions src/datachain/lib/dc.py
@@ -27,7 +27,15 @@
from datachain.lib.data_model import DataModel, DataType, dict_to_data_model
from datachain.lib.dataset_info import DatasetInfo
from datachain.lib.file import ExportPlacement as FileExportPlacement
- from datachain.lib.file import File, IndexedFile, get_file
+ from datachain.lib.file import File, IndexedFile, get_file_type
+ from datachain.lib.listing import (
+ is_listing_dataset,
+ is_listing_expired,
+ is_listing_subset,
+ list_bucket,
+ ls,
+ parse_listing_uri,
+ )
from datachain.lib.meta_formats import read_meta, read_schema
from datachain.lib.model_store import ModelStore
from datachain.lib.settings import Settings
@@ -311,7 +319,7 @@ def add_schema(self, signals_schema: SignalSchema) -> "Self": # noqa: D102
@classmethod
def from_storage(
cls,
- path,
+ uri,
*,
type: Literal["binary", "text", "image"] = "binary",
session: Optional[Session] = None,
@@ -320,41 +328,79 @@ def from_storage(
recursive: Optional[bool] = True,
object_name: str = "file",
update: bool = False,
- **kwargs,
+ anon: bool = False,
) -> "Self":
"""Get data from a storage as a list of file with all file attributes.
It returns the chain itself as usual.
Parameters:
- path : storage URI with directory. URI must start with storage prefix such
+ uri : storage URI with directory. URI must start with storage prefix such
as `s3://`, `gs://`, `az://` or "file:///"
type : read file as "binary", "text", or "image" data. Default is "binary".
recursive : search recursively for the given path.
object_name : Created object column name.
update : force storage reindexing. Default is False.
+ anon : If True, we will treat cloud bucket as public one
Example:
```py
chain = DataChain.from_storage("s3://my-bucket/my-dir")
```
"""
- func = get_file(type)
- return (
- cls(
- path,
- session=session,
- settings=settings,
- recursive=recursive,
- update=update,
- in_memory=in_memory,
- **kwargs,
- )
- .map(**{object_name: func})
- .select(object_name)
+ file_type = get_file_type(type)
+
+ if anon:
+ client_config = {"anon": True}
+ else:
+ client_config = None
+
+ session = Session.get(session, client_config=client_config, in_memory=in_memory)
+
+ list_dataset_name, list_uri, list_path = parse_listing_uri(
+ uri, session.catalog.cache, session.catalog.client_config
+ )
+ need_listing = True
+
+ for ds in cls.datasets(
+ session=session, in_memory=in_memory, include_listing=True
+ ).collect("dataset"):
+ if (
+ not is_listing_expired(ds.created_at) # type: ignore[union-attr]
+ and is_listing_dataset(ds.name) # type: ignore[union-attr]
+ and is_listing_subset(ds.name, list_dataset_name) # type: ignore[union-attr]
+ and not update
+ ):
+ need_listing = False
+ list_dataset_name = ds.name # type: ignore[union-attr]
+
+ if need_listing:
+ # caching new listing to special listing dataset
+ (
+ cls.from_records(
+ DataChain.DEFAULT_FILE_RECORD,
+ session=session,
+ settings=settings,
+ in_memory=in_memory,
+ )
+ .gen(
+ list_bucket(list_uri, client_config=session.catalog.client_config),
+ output={f"{object_name}": File},
+ )
+ .save(list_dataset_name, listing=True)
+ )
+
+ dc = cls.from_dataset(list_dataset_name, session=session)
+ dc.signals_schema = dc.signals_schema.mutate({f"{object_name}": file_type})
+
+ return ls(dc, list_path, recursive=recursive, object_name=object_name)

@classmethod
- def from_dataset(cls, name: str, version: Optional[int] = None) -> "DataChain":
+ def from_dataset(
+ cls,
+ name: str,
+ version: Optional[int] = None,
+ session: Optional[Session] = None,
+ ) -> "DataChain":
"""Get data from a saved Dataset. It returns the chain itself.
Parameters:
Expand All @@ -366,7 +412,7 @@ def from_dataset(cls, name: str, version: Optional[int] = None) -> "DataChain":
chain = DataChain.from_dataset("my_cats")
```
"""
- return DataChain(name=name, version=version)
+ return DataChain(name=name, version=version, session=session)

@classmethod
def from_json(
@@ -419,7 +465,7 @@ def jmespath_to_name(s: str):
object_name = jmespath_to_name(jmespath)
if not object_name:
object_name = meta_type
- chain = DataChain.from_storage(path=path, type=type, **kwargs)
+ chain = DataChain.from_storage(uri=path, type=type, **kwargs)
signal_dict = {
object_name: read_meta(
schema_from=schema_from,
@@ -479,7 +525,7 @@ def jmespath_to_name(s: str):
object_name = jmespath_to_name(jmespath)
if not object_name:
object_name = meta_type
- chain = DataChain.from_storage(path=path, type=type, **kwargs)
+ chain = DataChain.from_storage(uri=path, type=type, **kwargs)
signal_dict = {
object_name: read_meta(
schema_from=schema_from,
@@ -500,6 +546,7 @@ def datasets(
settings: Optional[dict] = None,
in_memory: bool = False,
object_name: str = "dataset",
+ include_listing: bool = False,
) -> "DataChain":
"""Generate chain with list of registered datasets.
@@ -517,7 +564,9 @@

datasets = [
DatasetInfo.from_models(d, v, j)
- for d, v, j in catalog.list_datasets_versions()
+ for d, v, j in catalog.list_datasets_versions(
+ include_listing=include_listing
+ )
]

return cls.from_values(
@@ -570,7 +619,7 @@ def print_jsonl_schema( # type: ignore[override]
)

def save( # type: ignore[override]
- self, name: Optional[str] = None, version: Optional[int] = None
+ self, name: Optional[str] = None, version: Optional[int] = None, **kwargs
) -> "Self":
"""Save to a Dataset. It returns the chain itself.
Expand All @@ -580,7 +629,7 @@ def save( # type: ignore[override]
version : version of a dataset. Default - the last version that exist.
"""
schema = self.signals_schema.clone_without_sys_signals().serialize()
- return super().save(name=name, version=version, feature_schema=schema)
+ return super().save(name=name, version=version, feature_schema=schema, **kwargs)

def apply(self, func, *args, **kwargs):
"""Apply any function to the chain.
@@ -1665,7 +1714,10 @@ def from_records(

if schema:
signal_schema = SignalSchema(schema)
- columns = signal_schema.db_signals(as_columns=True) # type: ignore[assignment]
+ columns = [
+ sqlalchemy.Column(c.name, c.type) # type: ignore[union-attr]
+ for c in signal_schema.db_signals(as_columns=True) # type: ignore[assignment]
+ ]
else:
columns = [
sqlalchemy.Column(name, typ)
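Taken together, the refactored `from_storage()` now materializes (or reuses) an internal listing dataset and then applies `ls()` over it. A minimal usage sketch, assuming `DataChain` is importable from the top-level `datachain` package as in the repo's examples; the bucket path is illustrative:

```python
from datachain import DataChain

# First call lists the bucket and caches the result in an internal "lst__..." listing dataset.
chain = DataChain.from_storage("gs://datachain-demo/dogs-and-cats/", anon=True)
chain.show(3)

# Later calls over the same path (or a sub-path of it) reuse the cached listing
# until it expires; update=True forces a fresh listing of the bucket.
fresh = DataChain.from_storage("gs://datachain-demo/dogs-and-cats/", anon=True, update=True)
```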
43 changes: 10 additions & 33 deletions src/datachain/lib/file.py
@@ -349,39 +349,6 @@ def save(self, destination: str):
self.read().save(destination)


- def get_file(type_: Literal["binary", "text", "image"] = "binary"):
- file: type[File] = File
- if type_ == "text":
- file = TextFile
- elif type_ == "image":
- file = ImageFile # type: ignore[assignment]
-
- def get_file_type(
- source: str,
- path: str,
- size: int,
- version: str,
- etag: str,
- is_latest: bool,
- last_modified: datetime,
- location: Optional[Union[dict, list[dict]]],
- vtype: str,
- ) -> file: # type: ignore[valid-type]
- return file(
- source=source,
- path=path,
- size=size,
- version=version,
- etag=etag,
- is_latest=is_latest,
- last_modified=last_modified,
- location=location,
- vtype=vtype,
- )
-
- return get_file_type


class IndexedFile(DataModel):
"""Metadata indexed from tabular files.
Expand All @@ -390,3 +357,13 @@ class IndexedFile(DataModel):

file: File
index: int


+ def get_file_type(type_: Literal["binary", "text", "image"] = "binary") -> type[File]:
+ file: type[File] = File
+ if type_ == "text":
+ file = TextFile
+ elif type_ == "image":
+ file = ImageFile # type: ignore[assignment]
+
+ return file
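A quick sketch of the renamed helper's behavior, matching the code added above; `from_storage()` uses the returned class to mutate the `file` signal type:

```python
from datachain.lib.file import File, ImageFile, TextFile, get_file_type

assert get_file_type() is File            # default is "binary"
assert get_file_type("text") is TextFile
assert get_file_type("image") is ImageFile
```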
