diff --git a/src/datachain/func/__init__.py b/src/datachain/func/__init__.py index b7cea4b46..fc7249e0f 100644 --- a/src/datachain/func/__init__.py +++ b/src/datachain/func/__init__.py @@ -16,7 +16,7 @@ sum, ) from .array import cosine_distance, euclidean_distance, length, sip_hash_64 -from .conditional import case, greatest, ifelse, least +from .conditional import case, greatest, ifelse, isnone, least from .numeric import bit_and, bit_hamming_distance, bit_or, bit_xor, int_hash_64 from .random import rand from .string import byte_hamming_distance @@ -42,6 +42,7 @@ "greatest", "ifelse", "int_hash_64", + "isnone", "least", "length", "literal", diff --git a/src/datachain/func/conditional.py b/src/datachain/func/conditional.py index 363e9d216..6f7457c54 100644 --- a/src/datachain/func/conditional.py +++ b/src/datachain/func/conditional.py @@ -1,14 +1,15 @@ -from typing import Union +from typing import Optional, Union +from sqlalchemy import ColumnElement from sqlalchemy import case as sql_case -from sqlalchemy.sql.elements import BinaryExpression from datachain.lib.utils import DataChainParamsError +from datachain.query.schema import Column from datachain.sql.functions import conditional from .func import ColT, Func -CaseT = Union[int, float, complex, bool, str] +CaseT = Union[int, float, complex, bool, str, Func] def greatest(*args: Union[ColT, float]) -> Func: @@ -87,17 +88,21 @@ def least(*args: Union[ColT, float]) -> Func: ) -def case(*args: tuple[BinaryExpression, CaseT], else_=None) -> Func: +def case( + *args: tuple[Union[ColumnElement, Func], CaseT], else_: Optional[CaseT] = None +) -> Func: """ Returns the case function that produces case expression which has a list of - conditions and corresponding results. Results can only be python primitives - like string, numbes or booleans. Result type is inferred from condition results. + conditions and corresponding results. Results can be python primitives like string, + numbers or booleans but can also be other nested function (including case function). + Result type is inferred from condition results. Args: - args (tuple(BinaryExpression, value(str | int | float | complex | bool): - - Tuple of binary expression and values pair which corresponds to one - case condition - value - else_ (str | int | float | complex | bool): else value in case expression + args (tuple((ColumnElement, Func), (str | int | float | complex | bool, Func))): + Tuple of condition and values pair. + else_ (str | int | float | complex | bool, Func): optional else value in case + expression. If omitted, and no case conditions are satisfied, the result + will be None (NULL in DB). Returns: Func: A Func object that represents the case function. @@ -111,15 +116,24 @@ def case(*args: tuple[BinaryExpression, CaseT], else_=None) -> Func: """ supported_types = [int, float, complex, str, bool] - type_ = type(else_) if else_ else None + def _get_type(val): + if isinstance(val, Func): + # nested functions + return val.result_type + return type(val) if not args: raise DataChainParamsError("Missing statements") + type_ = _get_type(else_) if else_ is not None else None + for arg in args: - if type_ and not isinstance(arg[1], type_): - raise DataChainParamsError("Statement values must be of the same type") - type_ = type(arg[1]) + arg_type = _get_type(arg[1]) + if type_ and arg_type != type_: + raise DataChainParamsError( + f"Statement values must be of the same type, got {type_} and {arg_type}" + ) + type_ = arg_type if type_ not in supported_types: raise DataChainParamsError( @@ -127,20 +141,25 @@ def case(*args: tuple[BinaryExpression, CaseT], else_=None) -> Func: ) kwargs = {"else_": else_} - return Func("case", inner=sql_case, args=args, kwargs=kwargs, result_type=type_) + + return Func("case", inner=sql_case, cols=args, kwargs=kwargs, result_type=type_) -def ifelse(condition: BinaryExpression, if_val: CaseT, else_val: CaseT) -> Func: +def ifelse( + condition: Union[ColumnElement, Func], if_val: CaseT, else_val: CaseT +) -> Func: """ Returns the ifelse function that produces if expression which has a condition - and values for true and false outcome. Results can only be python primitives - like string, numbes or booleans. Result type is inferred from the values. + and values for true and false outcome. Results can be one of python primitives + like string, numbers or booleans, but can also be nested functions. + Result type is inferred from the values. Args: - condition: BinaryExpression - condition which is evaluated - if_val: (str | int | float | complex | bool): value for true condition outcome - else_val: (str | int | float | complex | bool): value for false condition - outcome + condition (ColumnElement, Func): Condition which is evaluated. + if_val (str | int | float | complex | bool, Func): Value for true + condition outcome. + else_val (str | int | float | complex | bool, Func): Value for false condition + outcome. Returns: Func: A Func object that represents the ifelse function. @@ -148,8 +167,33 @@ def ifelse(condition: BinaryExpression, if_val: CaseT, else_val: CaseT) -> Func: Example: ```py dc.mutate( - res=func.ifelse(C("num") > 0, "P", "N"), + res=func.ifelse(isnone("col"), "EMPTY", "NOT_EMPTY") ) ``` """ return case((condition, if_val), else_=else_val) + + +def isnone(col: Union[str, Column]) -> Func: + """ + Returns True if column value is None, otherwise False. + + Args: + col (str | Column): Column to check if it's None or not. + If a string is provided, it is assumed to be the name of the column. + + Returns: + Func: A Func object that represents the conditional to check if column is None. + + Example: + ```py + dc.mutate(test=ifelse(isnone("col"), "EMPTY", "NOT_EMPTY")) + ``` + """ + from datachain import C + + if isinstance(col, str): + # if string, it is assumed to be the name of the column + col = C(col) + + return case((col.is_(None) if col is not None else True, True), else_=False) diff --git a/src/datachain/func/func.py b/src/datachain/func/func.py index 90ee5796e..072519df2 100644 --- a/src/datachain/func/func.py +++ b/src/datachain/func/func.py @@ -23,7 +23,7 @@ from .window import Window -ColT = Union[str, ColumnElement, "Func"] +ColT = Union[str, ColumnElement, "Func", tuple] class Func(Function): @@ -78,7 +78,7 @@ def _db_cols(self) -> Sequence[ColT]: return ( [ col - if isinstance(col, (Func, BindParameter, Case, Comparator)) + if isinstance(col, (Func, BindParameter, Case, Comparator, tuple)) else ColumnMeta.to_db_name( col.name if isinstance(col, ColumnElement) else col ) @@ -381,17 +381,24 @@ def get_column( col_type = self.get_result_type(signals_schema) sql_type = python_to_sql(col_type) - def get_col(col: ColT) -> ColT: + def get_col(col: ColT, string_as_literal=False) -> ColT: + # string_as_literal is used only for conditionals like `case()` where + # literals are nested inside ColT as we have tuples of condition - values + # and if user wants to set some case value as column, explicit `C("col")` + # syntax must be used to distinguish from literals + if isinstance(col, tuple): + return tuple(get_col(x, string_as_literal=True) for x in col) if isinstance(col, Func): return col.get_column(signals_schema, table=table) - if isinstance(col, str): + if isinstance(col, str) and not string_as_literal: column = Column(col, sql_type) column.table = table return column return col cols = [get_col(col) for col in self._db_cols] - func_col = self.inner(*cols, *self.args, **self.kwargs) + kwargs = {k: get_col(v, string_as_literal=True) for k, v in self.kwargs.items()} + func_col = self.inner(*cols, *self.args, **kwargs) if self.is_window: if not self.window: @@ -416,6 +423,11 @@ def get_col(col: ColT) -> ColT: def get_db_col_type(signals_schema: "SignalSchema", col: ColT) -> "DataType": + if isinstance(col, tuple): + raise DataChainParamsError( + "Cannot get type from tuple, please provide type hint to the function" + ) + if isinstance(col, Func): return col.get_result_type(signals_schema) diff --git a/tests/unit/sql/test_conditional.py b/tests/unit/sql/test_conditional.py index b1c4f59bd..e78cdcb24 100644 --- a/tests/unit/sql/test_conditional.py +++ b/tests/unit/sql/test_conditional.py @@ -86,6 +86,19 @@ def test_case(warehouse, val, expected): assert result == ((expected,),) +@pytest.mark.parametrize( + "val,expected", + [ + (1, "A"), + (2, None), + ], +) +def test_case_without_else(warehouse, val, expected): + query = select(func.case(*[(val < 2, "A")])) + result = tuple(warehouse.db.execute(query)) + assert result == ((expected,),) + + def test_case_missing_statements(warehouse): with pytest.raises(DataChainParamsError) as exc_info: select(func.case(*[], else_="D")) @@ -96,7 +109,9 @@ def test_case_not_same_result_types(warehouse): val = 2 with pytest.raises(DataChainParamsError) as exc_info: select(func.case(*[(val > 1, "A"), (2 < val < 4, 5)], else_="D")) - assert str(exc_info.value) == "Statement values must be of the same type" + assert str(exc_info.value) == ( + "Statement values must be of the same type, got and " + ) def test_case_wrong_result_type(warehouse): @@ -124,3 +139,18 @@ def test_ifelse(warehouse, val, expected): query = select(func.ifelse(val <= 3, "L", "H")) result = tuple(warehouse.db.execute(query)) assert result == ((expected,),) + + +@pytest.mark.parametrize( + "val,expected", + [ + [None, True], + [func.literal("abcd"), False], + ], +) +def test_isnone(warehouse, val, expected): + from datachain.func.conditional import isnone + + query = select(isnone(val)) + result = tuple(warehouse.db.execute(query)) + assert result == ((expected,),) diff --git a/tests/unit/test_func.py b/tests/unit/test_func.py index ad6d062b0..4b4237f12 100644 --- a/tests/unit/test_func.py +++ b/tests/unit/test_func.py @@ -8,6 +8,7 @@ case, ifelse, int_hash_64, + isnone, literal, ) from datachain.func.random import rand @@ -18,6 +19,7 @@ sqlite_byte_hamming_distance, sqlite_int_hash_64, ) +from tests.utils import skip_if_not_sqlite @pytest.fixture() @@ -663,6 +665,59 @@ def test_case_mutate(dc, val, else_, type_): assert res.schema["test"] == type_ +@pytest.mark.parametrize( + "val,else_,type_", + [ + ["A", "D", str], + [1, 2, int], + [1.5, 2.5, float], + [True, False, bool], + ], +) +def test_nested_case_on_condition_mutate(dc, val, else_, type_): + res = dc.mutate( + test=case((case((C("num") < 2, True), else_=False), val), else_=else_) + ) + assert list(res.order_by("test").collect("test")) == sorted( + [val, else_, else_, else_, else_] + ) + assert res.schema["test"] == type_ + + +@pytest.mark.parametrize( + "v1,v2,v3,type_", + [ + ["A", "B", "C", str], + [1, 2, 3, int], + [1.5, 2.5, 3.5, float], + [False, True, True, bool], + ], +) +def test_nested_case_on_value_mutate(dc, v1, v2, v3, type_): + res = dc.mutate( + test=case((C("num") < 4, case((C("num") < 2, v1), else_=v2)), else_=v3) + ) + assert list(res.order_by("num").collect("test")) == sorted([v1, v2, v2, v3, v3]) + assert res.schema["test"] == type_ + + +@pytest.mark.parametrize( + "v1,v2,v3,type_", + [ + ["A", "B", "C", str], + [1, 2, 3, int], + [1.5, 2.5, 3.5, float], + [False, True, True, bool], + ], +) +def test_nested_case_on_else_mutate(dc, v1, v2, v3, type_): + res = dc.mutate( + test=case((C("num") < 3, v1), else_=case((C("num") < 4, v2), else_=v3)) + ) + assert list(res.order_by("num").collect("test")) == sorted([v1, v1, v2, v3, v3]) + assert res.schema["test"] == type_ + + @pytest.mark.parametrize( "if_val,else_val,type_", [ @@ -678,3 +733,31 @@ def test_ifelse_mutate(dc, if_val, else_val, type_): [if_val, else_val, else_val, else_val, else_val] ) assert res.schema["test"] == type_ + + +@pytest.mark.parametrize("col", ["val", C("val")]) +@skip_if_not_sqlite +def test_isnone_mutate(col): + dc = DataChain.from_values( + num=list(range(1, 6)), + val=[None if i > 3 else "A" for i in range(1, 6)], + ) + + res = dc.mutate(test=isnone(col)) + assert list(res.order_by("test").collect("test")) == sorted( + [False, False, False, True, True] + ) + assert res.schema["test"] is bool + + +@pytest.mark.parametrize("col", [C("val"), "val"]) +@skip_if_not_sqlite +def test_isnone_with_ifelse_mutate(col): + dc = DataChain.from_values( + num=list(range(1, 6)), + val=[None if i > 3 else "A" for i in range(1, 6)], + ) + + res = dc.mutate(test=ifelse(isnone(col), "NONE", "NOT_NONE")) + assert list(res.order_by("num").collect("test")) == ["NOT_NONE"] * 3 + ["NONE"] * 2 + assert res.schema["test"] is str