Implement chain group_by #482

Merged · 5 commits · Oct 16, 2024
2 changes: 2 additions & 0 deletions src/datachain/__init__.py
@@ -1,3 +1,4 @@
from datachain.lib import func
from datachain.lib.data_model import DataModel, DataType, is_chain_type
from datachain.lib.dc import C, Column, DataChain, Sys
from datachain.lib.file import (
@@ -34,6 +35,7 @@
"Sys",
"TarVFile",
"TextFile",
"func",
"is_chain_type",
"metrics",
"param",
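With `func` re-exported at the package root, aggregations can be imported alongside `DataChain`. A minimal sketch of the intended import style (the signal name below is hypothetical):

```py
from datachain import func

cnt = func.count()             # row count; needs no source column
total = func.sum("file.size")  # aggregate over a signal column
```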
8 changes: 8 additions & 0 deletions src/datachain/data_storage/sqlite.py
@@ -763,6 +763,14 @@ def copy_table(
query: Select,
progress_cb: Optional[Callable[[int], None]] = None,
) -> None:
if len(query._group_by_clause) > 0:
select_q = query.with_only_columns(
*[c for c in query.selected_columns if c.name != "sys__id"]
)
q = table.insert().from_select(list(select_q.selected_columns), select_q)
self.db.execute(q)
return

if "sys__id" in query.selected_columns:
col_id = query.selected_columns.sys__id
else:
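This new branch handles grouped queries: a `SELECT ... GROUP BY` yields one row per group, so the per-row `sys__id` column is stripped before the `INSERT ... FROM SELECT`. A self-contained sketch of the same SQLAlchemy pattern, with illustrative table and column names (not from this PR):

```py
import sqlalchemy as sa

metadata = sa.MetaData()
src = sa.Table(
    "src",
    metadata,
    sa.Column("sys__id", sa.Integer),
    sa.Column("ext", sa.String),
)
dst = sa.Table("dst", metadata, sa.Column("ext", sa.String), sa.Column("cnt", sa.Integer))

# Grouped query: sys__id is meaningless per group, so drop it before inserting.
query = sa.select(src.c.ext, sa.func.count().label("cnt")).group_by(src.c.ext)
select_q = query.with_only_columns(
    *[c for c in query.selected_columns if c.name != "sys__id"]
)
insert_q = dst.insert().from_select(list(select_q.selected_columns), select_q)
```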
20 changes: 8 additions & 12 deletions src/datachain/lib/convert/sql_to_python.py
@@ -4,15 +4,11 @@
from sqlalchemy import ColumnElement


def sql_to_python(args_map: dict[str, ColumnElement]) -> dict[str, Any]:
res = {}
for name, sql_exp in args_map.items():
try:
type_ = sql_exp.type.python_type
if type_ == Decimal:
type_ = float
except NotImplementedError:
type_ = str
res[name] = type_

return res
def sql_to_python(sql_exp: ColumnElement) -> Any:
try:
type_ = sql_exp.type.python_type
if type_ == Decimal:
type_ = float
except NotImplementedError:
type_ = str
return type_
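`sql_to_python` now maps a single `ColumnElement` to a Python type instead of a whole dict; callers such as `SignalSchema.mutate` invoke it per expression. A quick sketch of the mapping, assuming SQLAlchemy's standard `python_type` behavior:

```py
from decimal import Decimal

import sqlalchemy as sa

from datachain.lib.convert.sql_to_python import sql_to_python

assert sql_to_python(sa.literal(1)) is int    # Integer -> int
assert sql_to_python(sa.literal("a")) is str  # String -> str
# Numeric columns deliberately map to float rather than Decimal:
assert sql_to_python(sa.literal(Decimal("1.5"))) is float
```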
84 changes: 62 additions & 22 deletions src/datachain/lib/dc.py
@@ -29,6 +29,7 @@
from datachain.lib.dataset_info import DatasetInfo
from datachain.lib.file import ArrowRow, File, get_file_type
from datachain.lib.file import ExportPlacement as FileExportPlacement
from datachain.lib.func import Func
from datachain.lib.listing import (
is_listing_dataset,
is_listing_expired,
@@ -42,21 +43,12 @@
from datachain.lib.model_store import ModelStore
from datachain.lib.settings import Settings
from datachain.lib.signal_schema import SignalSchema
from datachain.lib.udf import (
Aggregator,
BatchMapper,
Generator,
Mapper,
UDFBase,
)
from datachain.lib.udf import Aggregator, BatchMapper, Generator, Mapper, UDFBase
from datachain.lib.udf_signature import UdfSignature
from datachain.lib.utils import DataChainParamsError
from datachain.lib.utils import DataChainColumnError, DataChainParamsError
from datachain.query import Session
from datachain.query.dataset import (
DatasetQuery,
PartitionByType,
)
from datachain.query.schema import DEFAULT_DELIMITER, Column
from datachain.query.dataset import DatasetQuery, PartitionByType
from datachain.query.schema import DEFAULT_DELIMITER, Column, ColumnMeta
from datachain.sql.functions import path as pathfunc
from datachain.telemetry import telemetry
from datachain.utils import batched_it, inside_notebook
@@ -149,11 +141,6 @@ def _get_str(on: Sequence[Union[str, sqlalchemy.ColumnElement]]) -> str:
super().__init__(f"Merge error on='{on_str}'{right_on_str}: {msg}")


class DataChainColumnError(DataChainParamsError): # noqa: D101
def __init__(self, col_name, msg): # noqa: D107
super().__init__(f"Error for column {col_name}: {msg}")


OutputType = Union[None, DataType, Sequence[str], dict[str, DataType]]


@@ -982,10 +969,9 @@ def distinct(self, arg: str, *args: str) -> "Self": # type: ignore[override]
row is left in the result set.

Example:
```py
dc.distinct("file.parent", "file.name")
```
"""
return self._evolve(
query=self._query.distinct(
@@ -1011,6 +997,60 @@ def select_except(self, *args: str) -> "Self":
query=self._query.select(*columns), signal_schema=new_schema
)

def group_by(
self,
*,
partition_by: Union[str, Sequence[str]],
**kwargs: Func,
) -> "Self":
"""Group rows by specified set of signals and return new signals
with aggregated values.

Example:
```py
chain = chain.group_by(
cnt=func.count(),
partition_by=("file_source", "file_ext"),
)
```
"""
if isinstance(partition_by, str):
partition_by = [partition_by]
if not partition_by:
raise ValueError("At least one column should be provided for partition_by")

if not kwargs:
raise ValueError("At least one column should be provided for group_by")
for col_name, func in kwargs.items():
if not isinstance(func, Func):
raise DataChainColumnError(
col_name,
f"Column {col_name} has type {type(func)} but expected Func object",
)

partition_by_columns: list[Column] = []
signal_columns: list[Column] = []
schema_fields: dict[str, DataType] = {}

# validate partition_by columns and add them to the schema
for col_name in partition_by:
col_db_name = ColumnMeta.to_db_name(col_name)
col_type = self.signals_schema.get_column_type(col_db_name)
col = Column(col_db_name, python_to_sql(col_type))
partition_by_columns.append(col)
schema_fields[col_db_name] = col_type

# validate signal columns and add them to the schema
for col_name, func in kwargs.items():
col = func.get_column(self.signals_schema, label=col_name)
signal_columns.append(col)
schema_fields[col_name] = func.get_result_type(self.signals_schema)

return self._evolve(
query=self._query.group_by(signal_columns, partition_by_columns),
signal_schema=SignalSchema(schema_fields),
)

def mutate(self, **kwargs) -> "Self":
"""Create new signals based on existing signals.

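End to end, `group_by` validates the `partition_by` signals against the schema, checks that every keyword value is a `Func`, and produces a chain whose schema contains only the grouping keys plus the aggregates. A usage sketch (dataset and signal names are hypothetical):

```py
from datachain import DataChain, func

chain = DataChain.from_dataset("my_dataset").group_by(
    cnt=func.count(),
    total=func.sum("size"),
    partition_by="ext",
)
# Resulting signals: ext (grouping key), cnt (int), total (inferred from "size")
```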
14 changes: 14 additions & 0 deletions src/datachain/lib/func/__init__.py
@@ -0,0 +1,14 @@
from .aggregate import any_value, avg, collect, concat, count, max, min, sum
from .func import Func

__all__ = [
"Func",
"any_value",
"avg",
"collect",
"concat",
"count",
"max",
"min",
"sum",
]
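Since `sum`, `min`, `max`, and `count` shadow Python builtins, importing the module rather than the individual names keeps the builtins intact:

```py
from datachain.lib import func  # or: from datachain import func

func.min("file.size")  # builtin min() stays untouched
```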
42 changes: 42 additions & 0 deletions src/datachain/lib/func/aggregate.py
@@ -0,0 +1,42 @@
from typing import Optional

from sqlalchemy import func as sa_func

from datachain.sql import functions as dc_func

from .func import Func


def count(col: Optional[str] = None) -> Func:
return Func(inner=sa_func.count, col=col, result_type=int)


def sum(col: str) -> Func:
return Func(inner=sa_func.sum, col=col)


def avg(col: str) -> Func:
return Func(inner=dc_func.aggregate.avg, col=col)


def min(col: str) -> Func:
return Func(inner=sa_func.min, col=col)


def max(col: str) -> Func:
return Func(inner=sa_func.max, col=col)


def any_value(col: str) -> Func:
return Func(inner=dc_func.aggregate.any_value, col=col)


def collect(col: str) -> Func:
return Func(inner=dc_func.aggregate.collect, col=col, is_array=True)


def concat(col: str, separator="") -> Func:
def inner(arg):
return dc_func.aggregate.group_concat(arg, separator)

return Func(inner=inner, col=col, result_type=str)
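`concat` is the only helper that takes an extra argument, so it wraps a small closure that forwards the separator to `group_concat`. How the string and array aggregations differ (column names hypothetical):

```py
from datachain.lib.func import collect, concat

names = concat("file.name", separator=", ")  # one joined string per group (result_type=str)
paths = collect("file.path")                 # one list per group (is_array=True)
```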
64 changes: 64 additions & 0 deletions src/datachain/lib/func/func.py
@@ -0,0 +1,64 @@
from typing import TYPE_CHECKING, Callable, Optional

from datachain.lib.convert.python_to_sql import python_to_sql
from datachain.lib.utils import DataChainColumnError
from datachain.query.schema import Column, ColumnMeta

if TYPE_CHECKING:
from datachain import DataType
from datachain.lib.signal_schema import SignalSchema


class Func:
def __init__(
self,
inner: Callable,
col: Optional[str] = None,
result_type: Optional["DataType"] = None,
is_array: bool = False,
) -> None:
self.inner = inner
self.col = col
self.result_type = result_type
self.is_array = is_array

@property
def db_col(self) -> Optional[str]:
return ColumnMeta.to_db_name(self.col) if self.col else None

def db_col_type(self, signals_schema: "SignalSchema") -> Optional["DataType"]:
if not self.db_col:
return None
col_type: type = signals_schema.get_column_type(self.db_col)
return list[col_type] if self.is_array else col_type # type: ignore[valid-type]

def get_result_type(self, signals_schema: "SignalSchema") -> "DataType":
col_type = self.db_col_type(signals_schema)

if self.result_type:
return self.result_type

if col_type:
return col_type

raise DataChainColumnError(
str(self.inner),
"Column name is required to infer result type",
)

def get_column(
self, signals_schema: "SignalSchema", label: Optional[str] = None
) -> Column:
if self.col:
if label == "collect":
print(label)
col_type = self.get_result_type(signals_schema)
col = Column(self.db_col, python_to_sql(col_type))
func_col = self.inner(col)
else:
func_col = self.inner()

if label:
func_col = func_col.label(label)

return func_col
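`get_column` is where a `Func` becomes a SQLAlchemy expression: with a source column it calls `inner(col)` on a typed `Column`, without one it calls `inner()`, and either result gets the label. A sketch of that logic in plain SQLAlchemy:

```py
import sqlalchemy as sa

# count() has no source column: inner() is called with no arguments.
count_col = sa.func.count().label("cnt")

# sum("size") is bound to a column: inner(col) receives the typed column.
size = sa.column("size", sa.Integer())
sum_col = sa.func.sum(size).label("total")
```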
12 changes: 9 additions & 3 deletions src/datachain/lib/signal_schema.py
@@ -400,6 +400,12 @@ def _set_file_stream(
if ModelStore.is_pydantic(finfo.annotation):
SignalSchema._set_file_stream(getattr(obj, field), catalog, cache)

def get_column_type(self, col_name: str) -> DataType:
for path, _type, has_subtree, _ in self.get_flat_tree():
if not has_subtree and DEFAULT_DELIMITER.join(path) == col_name:
return _type
raise SignalResolvingError([col_name], "is not found")

def db_signals(
self, name: Optional[str] = None, as_columns=False
) -> Union[list[str], list[Column]]:
@@ -490,7 +496,7 @@ def mutate(self, args_map: dict) -> "SignalSchema":
new_values[name] = args_map[name]
else:
# adding new signal
new_values.update(sql_to_python({name: value}))
new_values[name] = sql_to_python(value)

return SignalSchema(new_values)

@@ -534,12 +540,12 @@ def _build_tree(
for name, val in values.items()
}

def get_flat_tree(self) -> Iterator[tuple[list[str], type, bool, int]]:
def get_flat_tree(self) -> Iterator[tuple[list[str], DataType, bool, int]]:
yield from self._get_flat_tree(self.tree, [], 0)

def _get_flat_tree(
self, tree: dict, prefix: list[str], depth: int
) -> Iterator[tuple[list[str], type, bool, int]]:
) -> Iterator[tuple[list[str], DataType, bool, int]]:
for name, (type_, substree) in tree.items():
suffix = name.split(".")
new_prefix = prefix + suffix
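`get_column_type` resolves a flattened DB name by walking the schema tree and joining each leaf path with the default `__` delimiter. A sketch, assuming the stock `File` model:

```py
from datachain.lib.file import File
from datachain.lib.signal_schema import SignalSchema

schema = SignalSchema({"file": File})
schema.get_column_type("file__size")  # -> int (leaf nodes only; non-leaf names raise)
```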
5 changes: 5 additions & 0 deletions src/datachain/lib/utils.py
@@ -23,3 +23,8 @@ def __init__(self, message):
class DataChainParamsError(DataChainError):
def __init__(self, message):
super().__init__(message)


class DataChainColumnError(DataChainParamsError):
def __init__(self, col_name, msg):
super().__init__(f"Error for column {col_name}: {msg}")