Skip to content

Commit

Permalink
Add more group_by aggregate functions
Browse files Browse the repository at this point in the history
  • Loading branch information
dreadatour committed Oct 1, 2024
1 parent 8454938 commit 89ec9d0
Show file tree
Hide file tree
Showing 6 changed files with 367 additions and 108 deletions.
64 changes: 31 additions & 33 deletions src/datachain/lib/dc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1018,62 +1018,60 @@ def group_by(
**kwargs: Func,
) -> "Self":
"""Groups by specified set of signals."""
if not kwargs:
raise ValueError("At least one column should be provided for group_by")

partition_by = [partition_by] if isinstance(partition_by, str) else partition_by
if not partition_by:
raise ValueError("At least one column should be provided for partition_by")

all_columns = {
DEFAULT_DELIMITER.join(path): _type
for path, _type, has_subtree, _ in self.signals_schema.get_flat_tree()
if not has_subtree
}
if not kwargs:
raise ValueError("At least one column should be provided for group_by")
for col_name, func in kwargs.items():
if not isinstance(func, Func):
raise DataChainColumnError(
col_name,
f"Column {col_name} has type {type(func)} but expected Func object",
)

partition_by_columns = []
schema_columns = self.signals_schema.db_columns_types()
schema_fields = {}

# validate partition_by columns and add them to the schema
partition_by_columns: list[Column] = []
for col_name in partition_by:
col_type = all_columns.get(col_name)
db_col_name = col_name.replace(".", DEFAULT_DELIMITER)
col_type = schema_columns.get(db_col_name)
if col_type is None:
raise DataChainColumnError(
col_name, f"Column {col_name} not found in schema"
)
column = Column(col_name, python_to_sql(col_type))
partition_by_columns.append(column)
schema_fields[col_name] = col_type
partition_by_columns.append(Column(db_col_name, python_to_sql(col_type)))
schema_fields[db_col_name] = col_type

select_columns = []
for field, func in kwargs.items():
cols = []
# validate signal columns and add them to the schema
signal_columns: list[Column] = []
for col_name, func in kwargs.items():
result_type = func.result_type
for col_name in func.cols:
col_type = all_columns.get(col_name)
if func.col is None:
signal_columns.append(func.inner().label(col_name))
else:
col_type = schema_columns.get(func.col)
if col_type is None:
raise DataChainColumnError(
col_name, f"Column {col_name} not found in schema"
func.col, f"Column {func.col} not found in schema"
)
cols.append(Column(col_name, python_to_sql(col_type)))
if result_type is None:
result_type = col_type
elif col_type != result_type:
raise DataChainColumnError(
col_name,
(
f"Column {col_name} has type {col_type}"
f"but expected {result_type}"
),
)
col = Column(func.col, python_to_sql(col_type))
signal_columns.append(func.inner(col).label(col_name))

if result_type is None:
raise ValueError(
f"Cannot infer type for function {func} with columns {func.cols}"
raise DataChainColumnError(

Check warning on line 1067 in src/datachain/lib/dc.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/lib/dc.py#L1067

Added line #L1067 was not covered by tests
col_name, f"Cannot infer type for function {func}"
)

select_columns.append(func.inner(*cols).label(field))
schema_fields[field] = result_type
schema_fields[col_name] = result_type

return self._evolve(
query=self._query.group_by(select_columns, partition_by_columns),
query=self._query.group_by(signal_columns, partition_by_columns),
signal_schema=SignalSchema(schema_fields),
)

Expand Down
36 changes: 29 additions & 7 deletions src/datachain/lib/func.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from typing import TYPE_CHECKING, Callable, Optional

from sqlalchemy import func
from sqlalchemy import func as sa_func

from datachain.query.schema import DEFAULT_DELIMITER
from datachain.sql import functions as dc_func

if TYPE_CHECKING:
from datachain import DataType
Expand All @@ -12,17 +15,36 @@ class Func:
def __init__(
    self,
    inner: Callable,
    col: Optional[str] = None,
    result_type: Optional["DataType"] = None,
) -> None:
    """Wrap an aggregate function for use with ``DataChain.group_by``.

    Args:
        inner: Callable that builds the SQL aggregate expression
            (e.g. a sqlalchemy or datachain SQL function).
        col: Optional signal column name. Dotted signal paths are
            converted to flat DB column names via ``DEFAULT_DELIMITER``.
        result_type: Optional explicit result type. When ``None``, the
            type is inferred later from the column's schema type.
    """
    self.inner = inner
    # Signals use "." between nested fields; DB columns use the flat
    # DEFAULT_DELIMITER form, so normalize here once.
    self.col = col.replace(".", DEFAULT_DELIMITER) if col else None
    self.result_type = result_type


def sum(*cols: str) -> Func:
return Func(inner=func.sum, cols=cols)
def count(col: Optional[str] = None) -> Func:
    """Return a COUNT aggregate ``Func``; the result type is always ``int``.

    The column is optional: with no column, the aggregate counts rows.
    """
    return Func(sa_func.count, col=col, result_type=int)


def sum(col: str) -> Func:
    """Return a SUM aggregate ``Func`` over *col*.

    No explicit result type is set; it is inferred from the column's
    schema type when the group-by is applied.
    """
    return Func(sa_func.sum, col=col)


def avg(col: str) -> Func:
    """Return an AVG aggregate ``Func`` over *col*.

    Uses the datachain SQL ``avg`` function; the result type is
    inferred from the column's schema type at group-by time.
    """
    return Func(dc_func.avg, col=col)


def min(col: str) -> Func:
    """Return a MIN aggregate ``Func`` over *col*.

    The result type is inferred from the column's schema type at
    group-by time.
    """
    return Func(sa_func.min, col=col)


def max(col: str) -> Func:
    """Return a MAX aggregate ``Func`` over *col*.

    The result type is inferred from the column's schema type at
    group-by time.
    """
    return Func(sa_func.max, col=col)


def concat(col: str, separator: str = "") -> Func:
    """Return a string-concatenation aggregate ``Func`` over *col*.

    Grouped values are joined with *separator*; the result type is
    always ``str``.

    Args:
        col: Signal column whose values are concatenated.
        separator: String inserted between concatenated values
            (empty by default).
    """

    def inner(arg):
        # Close over the separator so Func's single-argument callable
        # protocol still works.
        return sa_func.aggregate_strings(arg, separator)

    return Func(inner=inner, col=col, result_type=str)
7 changes: 7 additions & 0 deletions src/datachain/lib/signal_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,13 @@ def _set_file_stream(
if ModelStore.is_pydantic(finfo.annotation):
SignalSchema._set_file_stream(getattr(obj, field), catalog, cache)

def db_columns_types(self) -> dict[str, type]:
    """Map flat DB column names to their Python types.

    Walks the flattened signal tree, keeping only leaf entries (nodes
    without a subtree) and joining each path with ``DEFAULT_DELIMITER``
    to form the DB column name.
    """
    columns: dict[str, type] = {}
    for path, col_type, has_subtree, _ in self.get_flat_tree():
        if has_subtree:
            continue
        columns[DEFAULT_DELIMITER.join(path)] = col_type
    return columns

def db_signals(
self, name: Optional[str] = None, as_columns=False
) -> Union[list[str], list[Column]]:
Expand Down
39 changes: 5 additions & 34 deletions src/datachain/query/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -962,39 +962,21 @@ class SQLGroupBy(SQLClause):
group_by: Sequence[Union[str, ColumnElement]]

def apply_sql_clause(self, query) -> Select:
if not self.cols:
raise ValueError("No columns to select")

Check warning on line 966 in src/datachain/query/dataset.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/query/dataset.py#L966

Added line #L966 was not covered by tests
if not self.group_by:
raise ValueError("No columns to group by")

Check warning on line 968 in src/datachain/query/dataset.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/query/dataset.py#L968

Added line #L968 was not covered by tests

subquery = query.subquery()

cols = [
subquery.c[str(c)] if isinstance(c, (str, C)) else c
for c in [*self.group_by, *self.cols]
]
if not cols:
cols = subquery.c

return sqlalchemy.select(*cols).select_from(subquery).group_by(*self.group_by)


@frozen
class GroupBy(Step):
"""Group rows by a specific column."""

cols: PartitionByType

def clone(self) -> "Self":
return self.__class__(self.cols)

def apply(
self, query_generator: QueryGenerator, temp_tables: list[str]
) -> StepResult:
query = query_generator.select()
grouped_query = query.group_by(*self.cols)

def q(*columns):
return grouped_query.with_only_columns(*columns)

return step_result(q, grouped_query.selected_columns)


def _validate_columns(
left_columns: Iterable[ColumnElement], right_columns: Iterable[ColumnElement]
) -> set[str]:
Expand Down Expand Up @@ -1148,25 +1130,14 @@ def apply_steps(self) -> QueryGenerator:
query.steps = query.steps[-1:] + query.steps[:-1]

result = query.starting_step.apply()
group_by = None
self.dependencies.update(result.dependencies)

for step in query.steps:
if isinstance(step, GroupBy):
if group_by is not None:
raise TypeError("only one group_by allowed")
group_by = step
continue

result = step.apply(
result.query_generator, self.temp_table_names
) # a chain of steps linked by results
self.dependencies.update(result.dependencies)

if group_by:
result = group_by.apply(result.query_generator, self.temp_table_names)
self.dependencies.update(result.dependencies)

return result.query_generator

@staticmethod
Expand Down
Loading

0 comments on commit 89ec9d0

Please sign in to comment.