Skip to content

Commit

Permalink
Use 'sql_to_python' for GenericFunction type conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
dreadatour committed Oct 1, 2024
1 parent d49f4d5 commit 60c5392
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 41 deletions.
20 changes: 8 additions & 12 deletions src/datachain/lib/convert/sql_to_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,11 @@
from sqlalchemy import ColumnElement


def sql_to_python(args_map: dict[str, ColumnElement]) -> dict[str, Any]:
    """Translate named SQL expressions into their Python value types.

    Each entry of ``args_map`` is resolved through SQLAlchemy's
    ``python_type``. ``Decimal`` collapses to ``float``; types that cannot
    report a Python equivalent (``python_type`` raises
    ``NotImplementedError``, e.g. ``NullType``) fall back to ``str``.
    """
    result: dict[str, Any] = {}
    for key, expression in args_map.items():
        try:
            resolved = expression.type.python_type
        except NotImplementedError:
            # No Python equivalent is defined for this SQL type.
            result[key] = str
        else:
            result[key] = float if resolved == Decimal else resolved
    return result
def sql_to_python(sql_exp: ColumnElement) -> Any:
    """Return the Python type corresponding to ``sql_exp``'s SQL type.

    ``Decimal`` is narrowed to ``float``; types that cannot report a Python
    equivalent (``python_type`` raises ``NotImplementedError``, e.g.
    ``NullType``) fall back to ``str``.
    """
    try:
        resolved = sql_exp.type.python_type
    except NotImplementedError:
        # No Python equivalent is defined for this SQL type.
        return str
    return float if resolved == Decimal else resolved
13 changes: 8 additions & 5 deletions src/datachain/lib/dc.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from sqlalchemy.sql.sqltypes import NullType

from datachain.lib.convert.python_to_sql import python_to_sql
from datachain.lib.convert.sql_to_python import sql_to_python
from datachain.lib.convert.values_to_tuples import values_to_tuples
from datachain.lib.data_model import DataModel, DataType, dict_to_data_model
from datachain.lib.dataset_info import DatasetInfo
Expand Down Expand Up @@ -1038,7 +1039,9 @@ def group_by(
```
"""
partition_by = (
partition_by if isinstance(partition_by, (list, tuple)) else [partition_by]
[partition_by]
if isinstance(partition_by, (str, GenericFunction))
else partition_by
)
if not partition_by:
raise ValueError("At least one column should be provided for partition_by")
Expand All @@ -1053,14 +1056,14 @@ def group_by(
)

schema_columns = self.signals_schema.db_columns_types()
schema_fields = {}
schema_fields: dict[str, DataType] = {}

# validate partition_by columns and add them to the schema
partition_by_columns: list[Union[Column, GenericFunction]] = []
for col in partition_by:
if isinstance(col, GenericFunction):
partition_by_columns.append(col)
schema_fields[col.name] = col.type
schema_fields[col.name] = sql_to_python(col)

Check warning on line 1066 in src/datachain/lib/dc.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/lib/dc.py#L1065-L1066

Added lines #L1065 - L1066 were not covered by tests
else:
col_name = col.replace(".", DEFAULT_DELIMITER)
col_type = schema_columns.get(col_name)
Expand All @@ -1083,8 +1086,8 @@ def group_by(
)
if result_type is None:
result_type = col_type
col = Column(func.col, python_to_sql(col_type))
signal_columns.append(func.inner(col).label(col_name))
func_col = Column(func.col, python_to_sql(col_type))
signal_columns.append(func.inner(func_col).label(col_name))

if result_type is None:
raise DataChainColumnError(

Check warning on line 1093 in src/datachain/lib/dc.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/lib/dc.py#L1093

Added line #L1093 was not covered by tests
Expand Down
8 changes: 4 additions & 4 deletions src/datachain/lib/signal_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,7 @@ def _set_file_stream(
if ModelStore.is_pydantic(finfo.annotation):
SignalSchema._set_file_stream(getattr(obj, field), catalog, cache)

def db_columns_types(self) -> dict[str, type]:
def db_columns_types(self) -> dict[str, DataType]:
return {
DEFAULT_DELIMITER.join(path): _type
for path, _type, has_subtree, _ in self.get_flat_tree()
Expand Down Expand Up @@ -497,7 +497,7 @@ def mutate(self, args_map: dict) -> "SignalSchema":
new_values[name] = args_map[name]
else:
# adding new signal
new_values.update(sql_to_python({name: value}))
new_values[name] = sql_to_python(value)

return SignalSchema(new_values)

Expand Down Expand Up @@ -541,12 +541,12 @@ def _build_tree(
for name, val in values.items()
}

def get_flat_tree(self) -> Iterator[tuple[list[str], DataType, bool, int]]:
    """Yield every node of the signal tree, flattened depth-first.

    NOTE(review): tuple elements appear to be (path, type, has_subtree,
    depth), inferred from the signature and the recursive helper's
    prefix/depth arguments — confirm against ``_get_flat_tree``.
    """
    yield from self._get_flat_tree(self.tree, [], 0)
yield from self._get_flat_tree(self.tree, [], 0)

def _get_flat_tree(
self, tree: dict, prefix: list[str], depth: int
) -> Iterator[tuple[list[str], type, bool, int]]:
) -> Iterator[tuple[list[str], DataType, bool, int]]:
for name, (type_, substree) in tree.items():
suffix = name.split(".")
new_prefix = prefix + suffix
Expand Down
37 changes: 17 additions & 20 deletions tests/unit/lib/test_sql_to_python.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pytest
from sqlalchemy.sql.sqltypes import NullType

from datachain import Column
Expand All @@ -6,23 +7,19 @@
from datachain.sql.types import Float, Int64, String


def test_sql_columns_to_python_types():
    """Plain columns resolve to their natural Python types."""
    columns = {
        "name": Column("name", String),
        "age": Column("age", Int64),
        "score": Column("score", Float),
    }
    expected = {"name": str, "age": int, "score": float}
    assert sql_to_python(columns) == expected


def test_sql_expression_to_python_types():
    """Arithmetic over an integer column keeps the int type."""
    expression = Column("age", Int64) - 2
    assert sql_to_python({"age": expression}) == {"age": int}


def test_sql_function_to_python_types():
    """Aggregate functions report their own result type (avg -> float)."""
    aggregate = func.avg(Column("age", Int64))
    assert sql_to_python({"age": aggregate}) == {"age": float}


def test_sql_to_python_types_default_type():
    """Types without a Python equivalent fall back to str."""
    null_column = Column("null", NullType)
    assert sql_to_python({"null": null_column}) == {"null": str}
@pytest.mark.parametrize(
    "sql_column, expected",
    [
        (Column("name", String), str),
        (Column("age", Int64), int),
        (Column("score", Float), float),
        # SQL expression
        (Column("age", Int64) - 2, int),
        # SQL function
        (func.avg(Column("age", Int64)), float),
        # Default type
        (Column("null", NullType), str),
    ],
)
def test_sql_columns_to_python_types(sql_column, expected):
    """sql_to_python maps each SQL construct to its Python value type."""
    assert sql_to_python(sql_column) == expected

0 comments on commit 60c5392

Please sign in to comment.