Skip to content

Commit

Permalink
Add missing aggregate functions
Browse files Browse the repository at this point in the history
  • Loading branch information
dreadatour committed Oct 14, 2024
1 parent a00a1e8 commit 56999e8
Show file tree
Hide file tree
Showing 14 changed files with 602 additions and 412 deletions.
87 changes: 18 additions & 69 deletions src/datachain/lib/dc.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from sqlalchemy.sql.sqltypes import NullType

from datachain.lib.convert.python_to_sql import python_to_sql
from datachain.lib.convert.sql_to_python import sql_to_python
from datachain.lib.convert.values_to_tuples import values_to_tuples
from datachain.lib.data_model import DataModel, DataType, dict_to_data_model
from datachain.lib.dataset_info import DatasetInfo
Expand All @@ -44,21 +43,12 @@
from datachain.lib.model_store import ModelStore
from datachain.lib.settings import Settings
from datachain.lib.signal_schema import SignalSchema
from datachain.lib.udf import (
Aggregator,
BatchMapper,
Generator,
Mapper,
UDFBase,
)
from datachain.lib.udf import Aggregator, BatchMapper, Generator, Mapper, UDFBase
from datachain.lib.udf_signature import UdfSignature
from datachain.lib.utils import DataChainParamsError
from datachain.lib.utils import DataChainColumnError, DataChainParamsError
from datachain.query import Session
from datachain.query.dataset import (
DatasetQuery,
PartitionByType,
)
from datachain.query.schema import DEFAULT_DELIMITER, Column
from datachain.query.dataset import DatasetQuery, PartitionByType
from datachain.query.schema import DEFAULT_DELIMITER, Column, ColumnMeta
from datachain.sql.functions import path as pathfunc
from datachain.telemetry import telemetry
from datachain.utils import batched_it, inside_notebook
Expand Down Expand Up @@ -151,11 +141,6 @@ def _get_str(on: Sequence[Union[str, sqlalchemy.ColumnElement]]) -> str:
super().__init__(f"Merge error on='{on_str}'{right_on_str}: {msg}")


class DataChainColumnError(DataChainParamsError): # noqa: D101
def __init__(self, col_name, msg): # noqa: D107
super().__init__(f"Error for column {col_name}: {msg}")


OutputType = Union[None, DataType, Sequence[str], dict[str, DataType]]


Expand Down Expand Up @@ -1015,35 +1000,22 @@ def select_except(self, *args: str) -> "Self":
def group_by(
self,
*,
partition_by: Union[
Union[str, GenericFunction], Sequence[Union[str, GenericFunction]]
],
partition_by: Union[str, Sequence[str]],
**kwargs: Func,
) -> "Self":
"""Group rows by specified set of signals and return new signals
with aggregated values.
Example:
Using column name(s) in partition_by:
```py
chain = chain.group_by(
cnt=func.count(),
partition_by=("file_source", "file_ext"),
)
Using GenericFunction in partition_by:
```py
chain = chain.group_by(
total_size=func.sum("file.size"),
partition_by=func.file_ext(C("file.path")),
)
```
"""
partition_by = (
[partition_by]
if isinstance(partition_by, (str, GenericFunction))
else partition_by
)
if isinstance(partition_by, str):
partition_by = [partition_by]
if not partition_by:
raise ValueError("At least one column should be provided for partition_by")

Expand All @@ -1056,46 +1028,23 @@ def group_by(
f"Column {col_name} has type {type(func)} but expected Func object",
)

schema_columns = self.signals_schema.db_columns_types()
partition_by_columns: list[Column] = []
signal_columns: list[Column] = []
schema_fields: dict[str, DataType] = {}

# validate partition_by columns and add them to the schema
partition_by_columns: list[Union[Column, GenericFunction]] = []
for col in partition_by:
if isinstance(col, GenericFunction):
partition_by_columns.append(col)
schema_fields[col.name] = sql_to_python(col)
else:
col_name = col.replace(".", DEFAULT_DELIMITER)
col_type = schema_columns.get(col_name)
if col_type is None:
raise DataChainColumnError(col, f"Column {col} not found in schema")
partition_by_columns.append(Column(col_name, python_to_sql(col_type)))
schema_fields[col_name] = col_type
for col_name in partition_by:
col_db_name = ColumnMeta.to_db_name(col_name)
col_type = self.signals_schema.get_column_type(col_db_name)
col = Column(col_db_name, python_to_sql(col_type))
partition_by_columns.append(col)
schema_fields[col_db_name] = col_type

# validate signal columns and add them to the schema
signal_columns: list[Column] = []
for col_name, func in kwargs.items():
result_type = func.result_type
if func.col is None:
signal_columns.append(func.inner().label(col_name))
else:
col_type = schema_columns.get(func.col)
if col_type is None:
raise DataChainColumnError(
func.col, f"Column {func.col} not found in schema"
)
if result_type is None:
result_type = col_type
func_col = Column(func.col, python_to_sql(col_type))
signal_columns.append(func.inner(func_col).label(col_name))

if result_type is None:
raise DataChainColumnError(
col_name, f"Cannot infer type for function {func}"
)

schema_fields[col_name] = result_type
col = func.get_column(self.signals_schema, label=col_name)
signal_columns.append(col)
schema_fields[col_name] = func.get_result_type(self.signals_schema)

return self._evolve(
query=self._query.group_by(signal_columns, partition_by_columns),
Expand Down
50 changes: 0 additions & 50 deletions src/datachain/lib/func.py

This file was deleted.

14 changes: 14 additions & 0 deletions src/datachain/lib/func/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from .aggregate import any_value, avg, collect, concat, count, max, min, sum
from .func import Func

# Public API of datachain.lib.func: the Func wrapper plus the
# aggregate-function constructors re-exported from .aggregate.
__all__ = [
    "Func",
    "any_value",
    "avg",
    "collect",
    "concat",
    "count",
    "max",
    "min",
    "sum",
]
42 changes: 42 additions & 0 deletions src/datachain/lib/func/aggregate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from typing import Optional

from sqlalchemy import func as sa_func

from datachain.sql import functions as dc_func

from .func import Func


def count(col: Optional[str] = None) -> Func:
    """COUNT aggregate; counts all rows when *col* is omitted."""
    return Func(sa_func.count, col=col, result_type=int)


def sum(col: str) -> Func:  # noqa: A001 - deliberately mirrors SQL SUM
    """SUM aggregate over *col*; result type is inferred from the column."""
    return Func(sa_func.sum, col=col)


def avg(col: str) -> Func:
    """AVG aggregate over *col*; result type is inferred from the column."""
    return Func(dc_func.aggregate.avg, col=col)


def min(col: str) -> Func:  # noqa: A001 - deliberately mirrors SQL MIN
    """MIN aggregate over *col*; result type is inferred from the column."""
    return Func(sa_func.min, col=col)


def max(col: str) -> Func:  # noqa: A001 - deliberately mirrors SQL MAX
    """MAX aggregate over *col*; result type is inferred from the column."""
    return Func(sa_func.max, col=col)


def any_value(col: str) -> Func:
    """ANY_VALUE aggregate: an arbitrary value of *col* from the group."""
    return Func(dc_func.aggregate.any_value, col=col)


def collect(col: str) -> Func:
    """Collect all values of *col* in the group into an array.

    ``is_array=True`` makes the inferred result type ``list[<col type>]``.
    """
    return Func(dc_func.aggregate.collect, col=col, is_array=True)


def concat(col: str, separator="") -> Func:
    """GROUP_CONCAT aggregate: join values of *col* with *separator*."""

    def _group_concat(value):
        # Closure captures `separator` so Func can call it with the column only.
        return dc_func.aggregate.group_concat(value, separator)

    return Func(_group_concat, col=col, result_type=str)
64 changes: 64 additions & 0 deletions src/datachain/lib/func/func.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from typing import TYPE_CHECKING, Callable, Optional

from datachain.lib.convert.python_to_sql import python_to_sql
from datachain.lib.utils import DataChainColumnError
from datachain.query.schema import Column, ColumnMeta

if TYPE_CHECKING:
from datachain import DataType
from datachain.lib.signal_schema import SignalSchema


class Func:
    """An aggregate function to be applied during ``DataChain.group_by``.

    Wraps a SQL aggregate callable (``inner``) together with the optional
    column it aggregates over, an optional explicit result type, and a flag
    telling whether the result is an array of the column's type.
    """

    def __init__(
        self,
        inner: Callable,
        col: Optional[str] = None,
        result_type: Optional["DataType"] = None,
        is_array: bool = False,
    ) -> None:
        self.inner = inner
        self.col = col
        self.result_type = result_type
        self.is_array = is_array

    @property
    def db_col(self) -> Optional[str]:
        """DB-safe name of the source column, or None for zero-arg functions."""
        return ColumnMeta.to_db_name(self.col) if self.col else None

    def db_col_type(self, signals_schema: "SignalSchema") -> Optional["DataType"]:
        """Resolve the source column's type from the schema.

        Returns None when the function has no source column; wraps the type
        in ``list[...]`` for array-producing functions (e.g. ``collect``).
        """
        if not self.db_col:
            return None
        col_type: type = signals_schema.get_column_type(self.db_col)
        return list[col_type] if self.is_array else col_type  # type: ignore[valid-type]

    def get_result_type(self, signals_schema: "SignalSchema") -> "DataType":
        """Return the explicit result type, falling back to the column's type.

        Raises:
            DataChainColumnError: if neither an explicit result type nor a
                source column is available to infer the type from.
        """
        col_type = self.db_col_type(signals_schema)

        if self.result_type:
            return self.result_type

        if col_type:
            return col_type

        raise DataChainColumnError(
            str(self.inner),
            "Column name is required to infer result type",
        )

    def get_column(
        self, signals_schema: "SignalSchema", label: Optional[str] = None
    ) -> "Column":
        """Build the SQL column expression for this aggregate.

        Applies ``inner`` to the typed source column when one is set,
        otherwise calls it with no arguments (e.g. ``count()``); attaches
        *label* to the resulting expression if given.
        """
        # NOTE: removed leftover debug code (`if label == "collect": print(label)`).
        if self.col:
            col_type = self.get_result_type(signals_schema)
            col = Column(self.db_col, python_to_sql(col_type))
            func_col = self.inner(col)
        else:
            func_col = self.inner()

        if label:
            func_col = func_col.label(label)

        return func_col
11 changes: 5 additions & 6 deletions src/datachain/lib/signal_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,12 +400,11 @@ def _set_file_stream(
if ModelStore.is_pydantic(finfo.annotation):
SignalSchema._set_file_stream(getattr(obj, field), catalog, cache)

def db_columns_types(self) -> dict[str, DataType]:
return {
DEFAULT_DELIMITER.join(path): _type
for path, _type, has_subtree, _ in self.get_flat_tree()
if not has_subtree
}
def get_column_type(self, col_name: str) -> DataType:
    """Return the type of the flat (leaf) DB column named *col_name*.

    Raises:
        SignalResolvingError: if no leaf column with that name exists.
    """
    for path, column_type, has_subtree, _ in self.get_flat_tree():
        if has_subtree:
            continue  # only leaf signals map to DB columns
        if DEFAULT_DELIMITER.join(path) == col_name:
            return column_type
    raise SignalResolvingError([col_name], "is not found")

def db_signals(
self, name: Optional[str] = None, as_columns=False
Expand Down
5 changes: 5 additions & 0 deletions src/datachain/lib/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,8 @@ def __init__(self, message):
class DataChainParamsError(DataChainError):
    """Base error for invalid parameters passed to DataChain operations."""

    def __init__(self, message):
        super().__init__(message)


class DataChainColumnError(DataChainParamsError):
    """Raised when a column is missing from the schema or used incorrectly."""

    def __init__(self, col_name, msg):
        message = f"Error for column {col_name}: {msg}"
        super().__init__(message)
2 changes: 1 addition & 1 deletion src/datachain/sql/functions/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from sqlalchemy.sql.expression import func

from . import array, path, string
from .array import avg
from .aggregate import avg
from .conditional import greatest, least
from .random import rand

Expand Down
Loading

0 comments on commit 56999e8

Please sign in to comment.