Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
dreadatour committed Sep 11, 2024
1 parent bc094bf commit bbb129b
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 4 deletions.
1 change: 1 addition & 0 deletions src/datachain/catalog/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -1058,6 +1058,7 @@ def create_new_dataset_version(
if create_rows_table:
table_name = self.warehouse.dataset_table_name(dataset.name, version)
self.warehouse.create_dataset_rows_table(table_name, columns=columns)
self.update_dataset_version_with_warehouse_info(dataset, version)

return dataset

Expand Down
10 changes: 6 additions & 4 deletions src/datachain/data_storage/warehouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,12 +404,14 @@ def dataset_stats(
expressions: tuple[_ColumnsClauseArgument[Any], ...] = (
sa.func.count(table.c.sys__id),
)
for c in table.columns:
if c.name.endswith("file__size"):
expressions = (*expressions, sa.func.sum(c))
size_columns = [
c for c in table.columns if c.name == "size" or c.name.endswith("__size")
]
if size_columns:
expressions = (*expressions, sa.func.sum(sum(size_columns)))
query = select(*expressions)
((nrows, *rest),) = self.db.execute(query)
return nrows, sum(rest) if rest else 0
return nrows, rest[0] if rest else 0

def prepare_entries(
self, uri: str, entries: Iterable[Entry]
Expand Down
25 changes: 25 additions & 0 deletions tests/func/test_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import yaml
from fsspec.implementations.local import LocalFileSystem

from datachain import DataChain, File
from datachain.catalog import parse_edatachain_file
from datachain.cli import garbage_collect
from datachain.error import (
Expand Down Expand Up @@ -942,6 +943,30 @@ def test_query_save_size(cloud_test_catalog, mock_popen_dataset_created):
assert dataset_version.size == 15


def test_dataset_stats(test_session):
ids = [1, 2, 3]
values = tuple(zip(["a", "b", "c"], [1, 2, 3]))

ds1 = DataChain.from_values(
ids=ids,
file=[File(path=name, size=size) for name, size in values],
session=test_session,
).save()
dataset_version1 = test_session.catalog.get_dataset(ds1.name).get_version(1)
assert dataset_version1.num_objects == 3
assert dataset_version1.size == 6

ds2 = DataChain.from_values(
ids=ids,
file1=[File(path=name, size=size) for name, size in values],
file2=[File(path=name, size=size * 2) for name, size in values],
session=test_session,
).save()
dataset_version2 = test_session.catalog.get_dataset(ds2.name).get_version(1)
assert dataset_version2.num_objects == 3
assert dataset_version2.size == 18


def test_query_fail_to_compile(cloud_test_catalog):
catalog = cloud_test_catalog.catalog

Expand Down

0 comments on commit bbb129b

Please sign in to comment.