Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Removing SQLAlchemy from DataChain.compare() #881

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
206 changes: 79 additions & 127 deletions src/datachain/diff/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@

import sqlalchemy as sa

from datachain.func import case, ifelse, isnone
from datachain.lib.signal_schema import SignalSchema
from datachain.query.schema import Column
from datachain.sql.types import String

if TYPE_CHECKING:
from datachain.lib.dc import DataChain
Expand All @@ -32,7 +32,7 @@ class CompareStatus(str, Enum):
SAME = "S"


def _compare( # noqa: PLR0912, PLR0915, C901
def _compare( # noqa: C901
left: "DataChain",
right: "DataChain",
on: Union[str, Sequence[str]],
Expand All @@ -44,66 +44,49 @@ def _compare( # noqa: PLR0912, PLR0915, C901
modified: bool = True,
same: bool = True,
status_col: Optional[str] = None,
) -> "DataChain":
):
"""Comparing two chains by identifying rows that are added, deleted, modified
or same"""
dialect = left._query.dialect

rname = "right_"
schema = left.signals_schema # final chain must have schema from left chain

def _rprefix(c: str, rc: str) -> str:
    """Return the prefix carried by the right column of a merged column pair.

    ``c`` and ``rc`` are the companion left and right column names from the
    merge. The right column is only prefixed when both names are identical;
    otherwise it keeps its own name and the prefix is empty.
    """
    if c != rc:
        return ""
    return rname

def _to_list(obj: Union[str, Sequence[str]]) -> list[str]:
def _to_list(obj: Optional[Union[str, Sequence[str]]]) -> Optional[list[str]]:
if obj is None:
return None
return [obj] if isinstance(obj, str) else list(obj)

if on is None:
raise ValueError("'on' must be specified")

on = _to_list(on)
if right_on:
right_on = _to_list(right_on)
if len(on) != len(right_on):
raise ValueError("'on' and 'right_on' must be have the same length")

if compare:
compare = _to_list(compare)

if right_compare:
if not compare:
raise ValueError("'compare' must be defined if 'right_compare' is defined")

right_compare = _to_list(right_compare)
if len(compare) != len(right_compare):
raise ValueError(
"'compare' and 'right_compare' must be have the same length"
)
on = _to_list(on) # type: ignore[assignment]
right_on = _to_list(right_on)
compare = _to_list(compare)
right_compare = _to_list(right_compare)

if not any([added, deleted, modified, same]):
raise ValueError(
"At least one of added, deleted, modified, same flags must be set"
)

need_status_col = bool(status_col)
# we still need status column for internal implementation even if not
# needed in the output
status_col = status_col or get_status_col_name()

# calculate on and compare column names
right_on = right_on or on
if on is None:
raise ValueError("'on' must be specified")
if right_on and len(on) != len(right_on):
raise ValueError("'on' and 'right_on' must be have the same length")
if right_compare and not compare:
raise ValueError("'compare' must be defined if 'right_compare' is defined")
if compare and right_compare and len(compare) != len(right_compare):
raise ValueError("'compare' and 'right_compare' must have the same length")

# all left and right columns
cols = left.signals_schema.clone_without_sys_signals().db_signals()
right_cols = right.signals_schema.clone_without_sys_signals().db_signals()

# getting correct on and right_on column names
on = left.signals_schema.resolve(*on).db_signals() # type: ignore[assignment]
right_on = right.signals_schema.resolve(*right_on).db_signals() # type: ignore[assignment]
right_on = right.signals_schema.resolve(*(right_on or on)).db_signals() # type: ignore[assignment]

# getting correct compare and right_compare column names if they are defined
if compare:
right_compare = right_compare or compare
compare = left.signals_schema.resolve(*compare).db_signals() # type: ignore[assignment]
right_compare = right.signals_schema.resolve(*right_compare).db_signals() # type: ignore[assignment]
right_compare = right.signals_schema.resolve(
*(right_compare or compare)
).db_signals() # type: ignore[assignment]
elif not compare and len(cols) != len(right_cols):
# here we will mark all rows that are not added or deleted as modified since
# there was no explicit list of compare columns provided (meaning we need
Expand All @@ -113,103 +96,72 @@ def _to_list(obj: Union[str, Sequence[str]]) -> list[str]:
compare = None
right_compare = None
else:
compare = [c for c in cols if c in right_cols] # type: ignore[misc, assignment]
right_compare = compare
# we are checking all columns as explicit compare is not defined
compare = right_compare = [c for c in cols if c in right_cols and c not in on] # type: ignore[misc]

diff_cond = []
# get diff column names
diff_col = status_col or get_status_col_name()
ldiff_col = get_status_col_name()
rdiff_col = get_status_col_name()

if added:
added_cond = sa.and_(
*[
C(c) == None # noqa: E711
for c in [f"{_rprefix(c, rc)}{rc}" for c, rc in zip(on, right_on)]
]
)
diff_cond.append((added_cond, CompareStatus.ADDED))
if modified and compare:
modified_cond = sa.or_(
*[
C(c) != C(f"{_rprefix(c, rc)}{rc}")
for c, rc in zip(compare, right_compare) # type: ignore[arg-type]
]
)
diff_cond.append((modified_cond, CompareStatus.MODIFIED))
if same and compare:
same_cond = sa.and_(
# add helper diff columns; they are dropped again before the result is returned
left = left.mutate(**{ldiff_col: 1})
right = right.mutate(**{rdiff_col: 1})

if not compare:
modified_cond = True
else:
modified_cond = sa.or_( # type: ignore[assignment]
*[
C(c) == C(f"{_rprefix(c, rc)}{rc}")
C(c) != (C(f"{rname}{rc}") if c == rc else C(rc))
for c, rc in zip(compare, right_compare) # type: ignore[arg-type]
]
)
diff_cond.append((same_cond, CompareStatus.SAME))

diff = sa.case(*diff_cond, else_=None if compare else CompareStatus.MODIFIED).label(
status_col
)
diff.type = String()

left_right_merge = left.merge(
right, on=on, right_on=right_on, inner=False, rname=rname
)
left_right_merge_select = left_right_merge._query.select(
*(
[C(c) for c in left_right_merge.signals_schema.db_signals("sys")]
+ [C(c) for c in on]
+ [C(c) for c in cols if c not in on]
+ [diff]
dc_diff = (
left.merge(right, on=on, right_on=right_on, rname=rname, full=True)
.mutate(
**{
diff_col: case(
(isnone(ldiff_col), CompareStatus.DELETED),
(isnone(rdiff_col), CompareStatus.ADDED),
(modified_cond, CompareStatus.MODIFIED),
else_=CompareStatus.SAME,
)
}
)
)

diff_col = sa.literal(CompareStatus.DELETED).label(status_col)
diff_col.type = String()

right_left_merge = right.merge(
left, on=right_on, right_on=on, inner=False, rname=rname
).filter(
sa.and_(
*[C(f"{_rprefix(c, rc)}{c}") == None for c, rc in zip(on, right_on)] # noqa: E711
)
)

def _default_val(chain: "DataChain", col: str):
    """Build a labeled literal holding the default value of ``col``'s type.

    The column type is looked up in ``chain._query.column_types``; its default
    value for the enclosing ``dialect`` becomes a literal labeled ``col``.
    """
    column_type = chain._query.column_types[col]  # type: ignore[index]
    default_literal = sa.literal(column_type.default_value(dialect)).label(col)
    # give the literal an explicit type: an instance of the column's type
    default_literal.type = column_type()
    return default_literal

right_left_merge_select = right_left_merge._query.select(
*(
[C(c) for c in right_left_merge.signals_schema.db_signals("sys")]
+ [
C(c) if c == rc else _default_val(left, c)
for c, rc in zip(on, right_on)
]
+ [
C(c) if c in right_cols else _default_val(left, c) # type: ignore[arg-type]
for c in cols
if c not in on
]
+ [diff_col]
# when the row is deleted, we need to take column values from the right chain
.mutate(
**{
f"{c}": ifelse(
C(diff_col) == CompareStatus.DELETED, C(f"{rname}{c}"), C(c)
)
for c in [c for c in cols if c in right_cols]
}
)
.select_except(ldiff_col, rdiff_col)
)

if not added:
dc_diff = dc_diff.filter(C(diff_col) != CompareStatus.ADDED)
if not modified:
dc_diff = dc_diff.filter(C(diff_col) != CompareStatus.MODIFIED)
if not same:
dc_diff = dc_diff.filter(C(diff_col) != CompareStatus.SAME)
if not deleted:
res = left_right_merge_select
elif deleted and not any([added, modified, same]):
res = right_left_merge_select
else:
res = left_right_merge_select.union(right_left_merge_select)
dc_diff = dc_diff.filter(C(diff_col) != CompareStatus.DELETED)

res = res.filter(C(status_col) != None) # noqa: E711
if status_col:
cols.append(diff_col) # type: ignore[arg-type]

schema = left.signals_schema
if need_status_col:
res = res.select()
schema = SignalSchema({status_col: str}) | schema
else:
res = res.select_except(C(status_col))
dc_diff = dc_diff.select(*cols)

# final schema is schema from the left chain with status column added if needed
dc_diff.signals_schema = (
schema if not status_col else SignalSchema({status_col: str}) | schema
)

return left._evolve(query=res, signal_schema=schema)
return dc_diff


def compare_and_split(
Expand Down
8 changes: 6 additions & 2 deletions src/datachain/func/conditional.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def least(*args: Union[ColT, float]) -> Func:


def case(
*args: tuple[Union[ColumnElement, Func], CaseT], else_: Optional[CaseT] = None
*args: tuple[Union[ColumnElement, Func, bool], CaseT], else_: Optional[CaseT] = None
) -> Func:
"""
Returns the case function that produces case expression which has a list of
Expand All @@ -99,7 +99,7 @@ def case(
Result type is inferred from condition results.

Args:
args tuple((ColumnElement | Func),(str | int | float | complex | bool, Func, ColumnElement)):
args tuple((ColumnElement | Func | bool),(str | int | float | complex | bool, Func, ColumnElement)):
One or more (condition, value) tuples.
else_ (str | int | float | complex | bool, Func): optional else value in case
expression. If omitted, and no case conditions are satisfied, the result
Expand All @@ -118,12 +118,16 @@ def case(
supported_types = [int, float, complex, str, bool]

def _get_type(val):
    """Infer the Python type of a case result value.

    Returns ``None`` when the type cannot be determined (plain columns).
    """
    from enum import Enum

    if isinstance(val, Func):
        # nested functions expose their own result type
        inferred = val.result_type
    elif isinstance(val, Column):
        # at this point the type of a column cannot be known
        inferred = None
    elif isinstance(val, Enum):
        # enum members are typed by their underlying value
        inferred = type(val.value)
    else:
        inferred = type(val)
    return inferred

if not args:
Expand Down
21 changes: 16 additions & 5 deletions tests/unit/lib/test_diff.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from datetime import timezone

import pytest
from pydantic import BaseModel

Expand All @@ -8,6 +10,10 @@
from tests.utils import sorted_dicts


def _as_utc(d):
return d.replace(tzinfo=timezone.utc)


@pytest.mark.parametrize("added", (True, False))
@pytest.mark.parametrize("deleted", (True, False))
@pytest.mark.parametrize("modified", (True, False))
Expand Down Expand Up @@ -402,7 +408,7 @@ def test_compare_right_compare_wrong_length(test_session):
ds1.compare(ds2, on=["id"], compare=["name"], right_compare=["name", "city"])

assert str(exc_info.value) == (
"'compare' and 'right_compare' must be have the same length"
"'compare' and 'right_compare' must have the same length"
)


Expand Down Expand Up @@ -443,7 +449,11 @@ def test_diff(test_session, status_col):
expected = [row[1:] for row in expected]
collect_fields = collect_fields[1:]

assert list(diff.order_by("file.source").collect(*collect_fields)) == expected
res = list(diff.order_by("file.source").collect(*collect_fields))
for r in res:
r[-2].last_modified = _as_utc(r[-2].last_modified)

assert res == expected


@pytest.mark.parametrize("status_col", ("diff", None))
Expand Down Expand Up @@ -486,6 +496,7 @@ class Nested(BaseModel):
expected = [row[1:] for row in expected]
collect_fields = collect_fields[1:]

assert (
list(diff.order_by("nested.file.source").collect(*collect_fields)) == expected
)
res = list(diff.order_by("nested.file.source").collect(*collect_fields))
for r in res:
r[-2].file.last_modified = _as_utc(r[-2].file.last_modified)
assert res == expected