Skip to content

Commit

Permalink
fix file signal type from storage
Browse files Browse the repository at this point in the history
  • Loading branch information
ilongin committed Aug 19, 2024
1 parent 11df7a1 commit 9b006bd
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 12 deletions.
20 changes: 9 additions & 11 deletions src/datachain/lib/dc.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,11 +386,11 @@ def from_storage(
and not update
):
# we can use found listing as it contains the one from input
return ls(
cls.from_dataset(ds.name, **kwargs), # type: ignore[union-attr]
path,
recursive=recursive,
dc = cls.from_dataset(ds.name, **kwargs) # type: ignore[union-attr]
dc.signals_schema = dc.signals_schema.mutate(
{f"{object_name}": file_type}
)
return ls(dc, path, recursive=recursive)

# caching new listing to special listing dataset
(
Expand All @@ -402,17 +402,15 @@ def from_storage(
)
.gen(
list_bucket(lst_uri, **session.catalog.client_config),
output={f"{object_name}": file_type},
output={f"{object_name}": File},
)
.save(ds_name, listing=True)
)

return ls(
cls.from_dataset(ds_name, session=session, **kwargs),
path,
recursive=recursive,
object_name=object_name,
)
dc = cls.from_dataset(ds_name, session=session, **kwargs)
dc.signals_schema = dc.signals_schema.mutate({f"{object_name}": file_type})

return ls(dc, path, recursive=recursive, object_name=object_name)

@classmethod
def from_dataset(
Expand Down
4 changes: 4 additions & 0 deletions src/datachain/lib/signal_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,10 @@ def mutate(self, args_map: dict) -> "SignalSchema":
# renaming existing signal
del new_values[value.name]
new_values[name] = self.values[value.name]
elif name in self.values:
# changing the type of an existing signal, e.g. File -> ImageFile
del new_values[name]
new_values[name] = args_map[name]
else:
# adding new signal
new_values.update(sql_to_python({name: value}))
Expand Down
9 changes: 9 additions & 0 deletions tests/func/test_datachain.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,15 @@ def test_from_storage(cloud_test_catalog):
assert dc.count() == 7


def test_from_storage_as_image(cloud_test_catalog):
    """Every "file" value collected from a ``type="image"`` listing is an ImageFile."""
    fixture = cloud_test_catalog
    chain = DataChain.from_storage(
        fixture.src_uri,
        client_config=fixture.catalog.client_config,
        type="image",
    )
    # Check each collected object individually so a failure points at one item.
    for collected in chain.collect("file"):
        assert isinstance(collected, ImageFile)


def test_from_storage_reindex(tmp_dir, test_session):
tmp_dir = tmp_dir / "parquets"
path = tmp_dir.as_uri()
Expand Down
20 changes: 19 additions & 1 deletion tests/unit/lib/test_signal_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from datachain import Column, DataModel
from datachain.lib.convert.flatten import flatten
from datachain.lib.file import File
from datachain.lib.file import File, TextFile
from datachain.lib.signal_schema import (
SetupError,
SignalResolvingError,
Expand Down Expand Up @@ -338,3 +338,21 @@ def test_slice_nested():
keys = ["feature.aa"]
sliced = SignalSchema(schema).slice(keys)
assert list(sliced.values.items()) == [("feature.aa", int)]


def test_mutate_rename():
    """Mutating with a Column that refers to an existing signal renames it."""
    original = SignalSchema({"name": str})
    renamed = original.mutate({"new_name": Column("name")})
    # The old key is gone and its type is carried over to the new key.
    assert renamed.values == {"new_name": str}


def test_mutate_new_signal():
    """Mutating with a Column for an unknown name adds it as a new signal."""
    original = SignalSchema({"name": str})
    extended = original.mutate({"age": Column("age", Float)})
    # The existing signal is kept; the new one appears with the Python type float.
    assert extended.values == {"name": str, "age": float}


def test_mutate_change_type():
    """Mutating an existing signal name with a type replaces that signal's type."""
    original = SignalSchema({"name": str, "age": float, "f": File})
    retyped = original.mutate({"age": int, "f": TextFile})
    # "name" is not in the mutation map, so it stays as declared.
    assert retyped.values == {"name": str, "age": int, "f": TextFile}

0 comments on commit 9b006bd

Please sign in to comment.