From bbb129b9913128177544dbf2272f215010172309 Mon Sep 17 00:00:00 2001
From: Vladimir Rudnyh
Date: Wed, 11 Sep 2024 22:39:46 +0700
Subject: [PATCH] Fix tests

---
 src/datachain/catalog/catalog.py        |  1 +
 src/datachain/data_storage/warehouse.py | 10 ++++++----
 tests/func/test_catalog.py              | 25 +++++++++++++++++++++++++
 3 files changed, 32 insertions(+), 4 deletions(-)

diff --git a/src/datachain/catalog/catalog.py b/src/datachain/catalog/catalog.py
index 6d44cad15..084e68c56 100644
--- a/src/datachain/catalog/catalog.py
+++ b/src/datachain/catalog/catalog.py
@@ -1058,6 +1058,7 @@ def create_new_dataset_version(
         if create_rows_table:
             table_name = self.warehouse.dataset_table_name(dataset.name, version)
             self.warehouse.create_dataset_rows_table(table_name, columns=columns)
+            self.update_dataset_version_with_warehouse_info(dataset, version)
 
         return dataset
 
diff --git a/src/datachain/data_storage/warehouse.py b/src/datachain/data_storage/warehouse.py
index bfa9ad51e..dd1bf7130 100644
--- a/src/datachain/data_storage/warehouse.py
+++ b/src/datachain/data_storage/warehouse.py
@@ -404,12 +404,14 @@ def dataset_stats(
         expressions: tuple[_ColumnsClauseArgument[Any], ...] = (
             sa.func.count(table.c.sys__id),
         )
-        for c in table.columns:
-            if c.name.endswith("file__size"):
-                expressions = (*expressions, sa.func.sum(c))
+        size_columns = [
+            c for c in table.columns if c.name == "size" or c.name.endswith("__size")
+        ]
+        if size_columns:
+            expressions = (*expressions, sa.func.sum(sum(size_columns)))
         query = select(*expressions)
         ((nrows, *rest),) = self.db.execute(query)
-        return nrows, sum(rest) if rest else 0
+        return nrows, rest[0] if rest else 0
 
     def prepare_entries(
         self, uri: str, entries: Iterable[Entry]
diff --git a/tests/func/test_catalog.py b/tests/func/test_catalog.py
index 0373917d6..4c221d11d 100644
--- a/tests/func/test_catalog.py
+++ b/tests/func/test_catalog.py
@@ -8,6 +8,7 @@
 import yaml
 from fsspec.implementations.local import LocalFileSystem
 
+from datachain import DataChain, File
 from datachain.catalog import parse_edatachain_file
 from datachain.cli import garbage_collect
 from datachain.error import (
@@ -942,6 +943,30 @@ def test_query_save_size(cloud_test_catalog, mock_popen_dataset_created):
     assert dataset_version.size == 15
 
 
+def test_dataset_stats(test_session):
+    ids = [1, 2, 3]
+    values = tuple(zip(["a", "b", "c"], [1, 2, 3]))
+
+    ds1 = DataChain.from_values(
+        ids=ids,
+        file=[File(path=name, size=size) for name, size in values],
+        session=test_session,
+    ).save()
+    dataset_version1 = test_session.catalog.get_dataset(ds1.name).get_version(1)
+    assert dataset_version1.num_objects == 3
+    assert dataset_version1.size == 6
+
+    ds2 = DataChain.from_values(
+        ids=ids,
+        file1=[File(path=name, size=size) for name, size in values],
+        file2=[File(path=name, size=size * 2) for name, size in values],
+        session=test_session,
+    ).save()
+    dataset_version2 = test_session.catalog.get_dataset(ds2.name).get_version(1)
+    assert dataset_version2.num_objects == 3
+    assert dataset_version2.size == 18
+
+
 def test_query_fail_to_compile(cloud_test_catalog):
     catalog = cloud_test_catalog.catalog