From 9b006bd6000f95a58abd5a95602fc9e2545731ba Mon Sep 17 00:00:00 2001 From: Ivan Longin Date: Mon, 19 Aug 2024 11:51:16 +0200 Subject: [PATCH] fix file signal type from storage --- src/datachain/lib/dc.py | 20 +++++++++----------- src/datachain/lib/signal_schema.py | 4 ++++ tests/func/test_datachain.py | 9 +++++++++ tests/unit/lib/test_signal_schema.py | 20 +++++++++++++++++++- 4 files changed, 41 insertions(+), 12 deletions(-) diff --git a/src/datachain/lib/dc.py b/src/datachain/lib/dc.py index b2ddd6859..4fcf9b400 100644 --- a/src/datachain/lib/dc.py +++ b/src/datachain/lib/dc.py @@ -386,11 +386,11 @@ def from_storage( and not update ): # we can use found listing as it contains the one from input - return ls( - cls.from_dataset(ds.name, **kwargs), # type: ignore[union-attr] - path, - recursive=recursive, + dc = cls.from_dataset(ds.name, **kwargs) # type: ignore[union-attr] + dc.signals_schema = dc.signals_schema.mutate( + {f"{object_name}": file_type} ) + return ls(dc, path, recursive=recursive) # caching new listing to special listing dataset ( @@ -402,17 +402,15 @@ def from_storage( ) .gen( list_bucket(lst_uri, **session.catalog.client_config), - output={f"{object_name}": file_type}, + output={f"{object_name}": File}, ) .save(ds_name, listing=True) ) - return ls( - cls.from_dataset(ds_name, session=session, **kwargs), - path, - recursive=recursive, - object_name=object_name, - ) + dc = cls.from_dataset(ds_name, session=session, **kwargs) + dc.signals_schema = dc.signals_schema.mutate({f"{object_name}": file_type}) + + return ls(dc, path, recursive=recursive, object_name=object_name) @classmethod def from_dataset( diff --git a/src/datachain/lib/signal_schema.py b/src/datachain/lib/signal_schema.py index 338afd06d..8cc5d48ee 100644 --- a/src/datachain/lib/signal_schema.py +++ b/src/datachain/lib/signal_schema.py @@ -306,6 +306,10 @@ def mutate(self, args_map: dict) -> "SignalSchema": # renaming existing signal del new_values[value.name] new_values[name] = self.values[value.name] + elif name in self.values: + # changing the type of existing signal, e.g File -> ImageFile + del new_values[name] + new_values[name] = args_map[name] else: # adding new signal new_values.update(sql_to_python({name: value})) diff --git a/tests/func/test_datachain.py b/tests/func/test_datachain.py index 0a65bf7aa..b7e0500f2 100644 --- a/tests/func/test_datachain.py +++ b/tests/func/test_datachain.py @@ -29,6 +29,15 @@ def test_from_storage(cloud_test_catalog): assert dc.count() == 7 +def test_from_storage_as_image(cloud_test_catalog): + ctc = cloud_test_catalog + dc = DataChain.from_storage( + ctc.src_uri, client_config=ctc.catalog.client_config, type="image" + ) + for im in dc.collect("file"): + assert isinstance(im, ImageFile) + + def test_from_storage_reindex(tmp_dir, test_session): tmp_dir = tmp_dir / "parquets" path = tmp_dir.as_uri() diff --git a/tests/unit/lib/test_signal_schema.py b/tests/unit/lib/test_signal_schema.py index 13f2b8da6..09d257029 100644 --- a/tests/unit/lib/test_signal_schema.py +++ b/tests/unit/lib/test_signal_schema.py @@ -5,7 +5,7 @@ from datachain import Column, DataModel from datachain.lib.convert.flatten import flatten -from datachain.lib.file import File +from datachain.lib.file import File, TextFile from datachain.lib.signal_schema import ( SetupError, SignalResolvingError, @@ -338,3 +338,21 @@ def test_slice_nested(): keys = ["feature.aa"] sliced = SignalSchema(schema).slice(keys) assert list(sliced.values.items()) == [("feature.aa", int)] + + +def test_mutate_rename(): + schema = SignalSchema({"name": str}) + schema = schema.mutate({"new_name": Column("name")}) + assert schema.values == {"new_name": str} + + +def test_mutate_new_signal(): + schema = SignalSchema({"name": str}) + schema = schema.mutate({"age": Column("age", Float)}) + assert schema.values == {"name": str, "age": float} + + +def test_mutate_change_type(): + schema = SignalSchema({"name": str, "age": float, "f": File}) + schema = schema.mutate({"age": int, "f": TextFile}) + assert schema.values == {"name": str, "age": int, "f": TextFile}