Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added isnone() function #801

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
3 changes: 2 additions & 1 deletion src/datachain/func/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
sum,
)
from .array import cosine_distance, euclidean_distance, length, sip_hash_64
from .conditional import case, greatest, ifelse, least
from .conditional import case, greatest, ifelse, isnone, least
from .numeric import bit_and, bit_hamming_distance, bit_or, bit_xor, int_hash_64
from .random import rand
from .string import byte_hamming_distance
Expand All @@ -42,6 +42,7 @@
"greatest",
"ifelse",
"int_hash_64",
"isnone",
"least",
"length",
"literal",
Expand Down
86 changes: 64 additions & 22 deletions src/datachain/func/conditional.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from typing import Union
from typing import Optional, Union

from sqlalchemy import ColumnElement
from sqlalchemy import case as sql_case
from sqlalchemy.sql.elements import BinaryExpression

from datachain.lib.utils import DataChainParamsError
from datachain.query.schema import Column
from datachain.sql.functions import conditional

from .func import ColT, Func

CaseT = Union[int, float, complex, bool, str]
CaseT = Union[int, float, complex, bool, str, Func]

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

are there tests for Func values btw? can we add them?


def greatest(*args: Union[ColT, float]) -> Func:
Expand Down Expand Up @@ -87,17 +88,19 @@ def least(*args: Union[ColT, float]) -> Func:
)


def case(*args: tuple[BinaryExpression, CaseT], else_=None) -> Func:
def case(
*args: tuple[Union[ColumnElement, Func], CaseT], else_: Optional[CaseT] = None
) -> Func:
"""
Returns the case function that produces case expression which has a list of
conditions and corresponding results. Results can only be python primitives
like string, numbes or booleans. Result type is inferred from condition results.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is it still true? can result now be a Func?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, all methods here return Func which is essentially a wrapper around sqlalchemy functions like case which produces case construct

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean this part Results can only be python primitives ..., sorry

like string, numbers or booleans. Result type is inferred from condition results.

Args:
args (tuple(BinaryExpression, value(str | int | float | complex | bool):
- Tuple of binary expression and values pair which corresponds to one
case condition - value
else_ (str | int | float | complex | bool): else value in case expression
args (tuple((ColumnElement, Func), (str | int | float | complex | bool, Func))):
Tuple of condition and values pair.
else_ (str | int | float | complex | bool, Func): else value in case
expression.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is it optional? can we say optional else value. What happens by default?

(no need for "in case expression", since it's clear from the context)


Returns:
Func: A Func object that represents the case function.
Expand All @@ -111,45 +114,84 @@ def case(*args: tuple[BinaryExpression, CaseT], else_=None) -> Func:
"""
supported_types = [int, float, complex, str, bool]

type_ = type(else_) if else_ else None
def _get_type(val):
if isinstance(val, Func):
# nested functions
return val.result_type
return type(val)

if not args:
raise DataChainParamsError("Missing statements")

type_ = _get_type(else_) if else_ is not None else None

for arg in args:
if type_ and not isinstance(arg[1], type_):
raise DataChainParamsError("Statement values must be of the same type")
type_ = type(arg[1])
arg_type = _get_type(arg[1])
if type_ and arg_type != type_:
raise DataChainParamsError(
f"Statement values must be of the same type, got {type_} and {arg_type}"
)
type_ = arg_type

if type_ not in supported_types:
raise DataChainParamsError(
f"Only python literals ({supported_types}) are supported for values"
)

kwargs = {"else_": else_}
return Func("case", inner=sql_case, args=args, kwargs=kwargs, result_type=type_)

return Func("case", inner=sql_case, cols=args, kwargs=kwargs, result_type=type_)


def ifelse(condition: BinaryExpression, if_val: CaseT, else_val: CaseT) -> Func:
def ifelse(
condition: Union[ColumnElement, Func], if_val: CaseT, else_val: CaseT
) -> Func:
"""
Returns the ifelse function that produces if expression which has a condition
and values for true and false outcome. Results can only be python primitives
like string, numbes or booleans. Result type is inferred from the values.
and values for true and false outcome. Results can be one of python primitives
like string, numbers or booleans, but can also be nested functions.
Result type is inferred from the values.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

numbes -> numbers


Args:
condition: BinaryExpression - condition which is evaluated
if_val: (str | int | float | complex | bool): value for true condition outcome
else_val: (str | int | float | complex | bool): value for false condition
outcome
condition (ColumnElement, Func): Condition which is evaluated.
if_val (str | int | float | complex | bool, Func): Value for true
condition outcome.
else_val (str | int | float | complex | bool, Func): Value for false condition
outcome.

Returns:
Func: A Func object that represents the ifelse function.

Example:
```py
dc.mutate(
res=func.ifelse(C("num") > 0, "P", "N"),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's add another example with a column expression?

also an example with a result as an expression?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need that many examples? All other functions have exactly one example even though there can be more as well like the version with column expression as you stated.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To my mind, yes: examples are usually the most valuable part, since they give you an idea of what is actually possible. We need more simple examples everywhere.

res=func.ifelse(isnone("col"), "EMPTY", "NOT_EMPTY"),
)
```
"""
return case((condition, if_val), else_=else_val)


def isnone(col: Union[str, Column]) -> Func:
"""
Returns True if column value is None, otherwise False

Args:
col (str | Column): Column to check if it's None or not.
If a string is provided, it is assumed to be the name of the column.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

consistency

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm trying to figure it out, but I'm not sure what is wrong in this part of the docs?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

e.g. in some cases the description starts with a lowercase letter, sometimes there is a "-" before the description, sometimes there is no period at the end, sometimes it starts on a new line, etc.

and all of this within this single PR

and it also applies to the description, e.g. we don't have period here - is it always the case?


Returns:
Func: A Func object that represents the conditional to check if column is None.

Example:
```py
dc.mutate(test=ifelse(isnone("col"), "EMPTY", "NOT_EMPTY"))
```
"""
from datachain import C

if isinstance(col, str):
# if string, it is assumed to be the name of the column
col = C(col)

return case((col == None, True), else_=False) # noqa: E711
19 changes: 13 additions & 6 deletions src/datachain/func/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from .window import Window


ColT = Union[str, ColumnElement, "Func"]
ColT = Union[str, ColumnElement, "Func", tuple]


class Func(Function):
Expand Down Expand Up @@ -78,7 +78,7 @@ def _db_cols(self) -> Sequence[ColT]:
return (
[
col
if isinstance(col, (Func, BindParameter, Case, Comparator))
if isinstance(col, (Func, BindParameter, Case, Comparator, tuple))
else ColumnMeta.to_db_name(
col.name if isinstance(col, ColumnElement) else col
)
Expand Down Expand Up @@ -381,17 +381,24 @@ def get_column(
col_type = self.get_result_type(signals_schema)
sql_type = python_to_sql(col_type)

def get_col(col: ColT) -> ColT:
def get_col(col: ColT, string_as_literal=False) -> ColT:
# string_as_literal is used only for conditionals like `case()` where
# literals are nested inside ColT as we have tuples of condition - values
# and if user wants to set some case value as column, explicit `C("col")`
# syntax must be used to distinguish from literals
if isinstance(col, tuple):
return tuple(get_col(x, string_as_literal=True) for x in col)
if isinstance(col, Func):
return col.get_column(signals_schema, table=table)
if isinstance(col, str):
if isinstance(col, str) and not string_as_literal:
column = Column(col, sql_type)
column.table = table
return column
return col

cols = [get_col(col) for col in self._db_cols]
func_col = self.inner(*cols, *self.args, **self.kwargs)
kwargs = {k: get_col(v, string_as_literal=True) for k, v in self.kwargs.items()}
func_col = self.inner(*cols, *self.args, **kwargs)

if self.is_window:
if not self.window:
Expand Down Expand Up @@ -423,7 +430,7 @@ def get_db_col_type(signals_schema: "SignalSchema", col: ColT) -> "DataType":
return sql_to_python(col)

return signals_schema.get_column_type(
col.name if isinstance(col, ColumnElement) else col
col.name if isinstance(col, ColumnElement) else col # type: ignore[arg-type]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we have this ignore here?

)


Expand Down
19 changes: 18 additions & 1 deletion tests/unit/sql/test_conditional.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,9 @@ def test_case_not_same_result_types(warehouse):
val = 2
with pytest.raises(DataChainParamsError) as exc_info:
select(func.case(*[(val > 1, "A"), (2 < val < 4, 5)], else_="D"))
assert str(exc_info.value) == "Statement values must be of the same type"
assert str(exc_info.value) == (
"Statement values must be of the same type, got <class 'str'> and <class 'int'>"
)


def test_case_wrong_result_type(warehouse):
Expand Down Expand Up @@ -124,3 +126,18 @@ def test_ifelse(warehouse, val, expected):
query = select(func.ifelse(val <= 3, "L", "H"))
result = tuple(warehouse.db.execute(query))
assert result == ((expected,),)


@pytest.mark.parametrize(
    "val,expected",
    [
        (None, True),
        (func.literal("abcd"), False),
    ],
)
def test_isnone(warehouse, val, expected):
    """Check isnone() against a None value and a non-null string literal."""
    from datachain.func.conditional import isnone

    rows = tuple(warehouse.db.execute(select(isnone(val))))
    assert rows == ((expected,),)
82 changes: 82 additions & 0 deletions tests/unit/test_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
case,
ifelse,
int_hash_64,
isnone,
literal,
)
from datachain.func.random import rand
Expand All @@ -18,6 +19,7 @@
sqlite_byte_hamming_distance,
sqlite_int_hash_64,
)
from tests.utils import skip_if_not_sqlite


@pytest.fixture()
Expand Down Expand Up @@ -663,6 +665,59 @@ def test_case_mutate(dc, val, else_, type_):
assert res.schema["test"] == type_


@pytest.mark.parametrize(
    "val,else_,type_",
    [
        ("A", "D", str),
        (1, 2, int),
        (1.5, 2.5, float),
        (True, False, bool),
    ],
)
def test_nested_case_on_condition_mutate(dc, val, else_, type_):
    """Use a nested case() as the condition of an outer case()."""
    # Inner case yields a boolean that drives the outer case's only branch.
    inner_condition = case((C("num") < 2, True), else_=False)
    mutated = dc.mutate(test=case((inner_condition, val), else_=else_))

    actual = list(mutated.order_by("test").collect("test"))
    expected = sorted([val] + [else_] * 4)
    assert actual == expected
    assert mutated.schema["test"] == type_


@pytest.mark.parametrize(
    "v1,v2,v3,type_",
    [
        ["A", "B", "C", str],
        [1, 2, 3, int],
        [1.5, 2.5, 3.5, float],
        [False, True, True, bool],
    ],
)
def test_nested_case_on_value_mutate(dc, v1, v2, v3, type_):
    """Use a nested case() as the result value of an outer case() branch."""
    res = dc.mutate(
        test=case((C("num") < 4, case((C("num") < 2, v1), else_=v2)), else_=v3)
    )
    # Rows are ordered by "num", so the expectation is written in row order.
    # Sorting the expected values would only match while the parameter values
    # happen to be monotonic, and would break for e.g. ("C", "B", "A").
    assert list(res.order_by("num").collect("test")) == [v1, v2, v2, v3, v3]
    assert res.schema["test"] == type_


@pytest.mark.parametrize(
    "v1,v2,v3,type_",
    [
        ["A", "B", "C", str],
        [1, 2, 3, int],
        [1.5, 2.5, 3.5, float],
        [False, True, True, bool],
    ],
)
def test_nested_case_on_else_mutate(dc, v1, v2, v3, type_):
    """Use a nested case() as the else_ value of an outer case()."""
    res = dc.mutate(
        test=case((C("num") < 3, v1), else_=case((C("num") < 4, v2), else_=v3))
    )
    # Rows are ordered by "num", so the expectation is written in row order.
    # Sorting the expected values would only match while the parameter values
    # happen to be monotonic, and would break for e.g. ("C", "B", "A").
    assert list(res.order_by("num").collect("test")) == [v1, v1, v2, v3, v3]
    assert res.schema["test"] == type_


@pytest.mark.parametrize(
"if_val,else_val,type_",
[
Expand All @@ -678,3 +733,30 @@ def test_ifelse_mutate(dc, if_val, else_val, type_):
[if_val, else_val, else_val, else_val, else_val]
)
assert res.schema["test"] == type_


@pytest.mark.parametrize("col", ["val", C("val")])
def test_isnone_mutate(col):
    """Run isnone() in mutate() with the column given by name and as C(...)."""
    nums = list(range(1, 6))
    chain = DataChain.from_values(
        num=nums,
        # First three rows carry a value, the last two are None.
        val=["A" if n <= 3 else None for n in nums],
    )

    mutated = chain.mutate(test=isnone(col))
    flags = list(mutated.order_by("test").collect("test"))
    assert flags == sorted([False, False, False, True, True])
    assert mutated.schema["test"] is bool


@pytest.mark.parametrize("col", [C("val"), "val"])
@skip_if_not_sqlite
def test_isnone_with_ifelse_mutate(col):
dc = DataChain.from_values(
shcheklein marked this conversation as resolved.
Show resolved Hide resolved
num=list(range(1, 6)),
val=[None if i > 3 else "A" for i in range(1, 6)],
)

res = dc.mutate(test=ifelse(isnone(col), "NONE", "NOT_NONE"))
assert list(res.order_by("num").collect("test")) == ["NOT_NONE"] * 3 + ["NONE"] * 2
assert res.schema["test"] is str
Loading