Skip to content

Commit

Permalink
fixup: test merge on expression
Browse files Browse the repository at this point in the history
  • Loading branch information
mattseddon committed Sep 5, 2024
1 parent 61071f7 commit 5b86816
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 18 deletions.
16 changes: 8 additions & 8 deletions src/datachain/lib/dc.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,14 +128,14 @@ def __init__( # noqa: D107
right_on: Optional[Sequence[Union[str, sqlalchemy.ColumnElement]]],
msg: str,
):
def get_str(on: Sequence[Union[str, sqlalchemy.ColumnElement]]) -> str:
def _get_str(on: Sequence[Union[str, sqlalchemy.ColumnElement]]) -> str:
if not isinstance(on, Sequence):
return str(on) # type: ignore[unreachable]
return ", ".join([_get_merge_error_str(col) for col in on])

on_str = get_str(on)
on_str = _get_str(on)
right_on_str = (
", right_on='" + get_str(right_on) + "'"
", right_on='" + _get_str(right_on) + "'"
if right_on and isinstance(right_on, Sequence)
else ""
)
Expand Down Expand Up @@ -283,8 +283,8 @@ def column(self, name: str) -> Column:

raise ValueError(f"Column with name {name} not found in the schema")

def c(self, column: Union[str, Column]) -> "sqlalchemy.ColumnClause":
"""Returns ColumnClause instance attached to the current chain."""
def c(self, column: Union[str, Column]) -> Column:
"""Returns Column instance attached to the current chain."""
c = self.column(column) if isinstance(column, str) else self.column(column.name)
c.table = self.table
return c
Expand Down Expand Up @@ -1205,7 +1205,7 @@ def merge(

errors = []

Check warning on line 1207 in src/datachain/lib/dc.py

View check run for this annotation

Codecov / codecov/patch

src/datachain/lib/dc.py#L1207

Added line #L1207 was not covered by tests
def resolve(
def _resolve(
ds: DataChain,
col: Union[str, sqlalchemy.ColumnElement],
side: Union[str, None],
Expand All @@ -1217,8 +1217,8 @@ def resolve(
errors.append(f"{_get_merge_error_str(col)} in {side}")

ops = [
resolve(self, left, "left")
== resolve(right_ds, right, "right" if right_on else None)
_resolve(self, left, "left")
== _resolve(right_ds, right, "right" if right_on else None)
for left, right in zip(on, right_on or on)
]

Expand Down
26 changes: 16 additions & 10 deletions tests/unit/lib/test_datachain_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,21 +269,27 @@ def test_merge_with_itself_column(test_session):


def test_merge_on_expression(test_session):
ch = DataChain.from_values(team=team, session=test_session)
def _get_expr(dc):
c = dc.c("team.sport")
return func.substr(c, func.length(c) - 3)

dc = DataChain.from_values(team=team, session=test_session)
right_dc = dc.clone(new_table=True)

# cross join on "ball" from sport
c = ch.c("team.sport")
expr = func.substr(c, func.length(c) - 4)
merged = ch.merge(ch, on=expr)
merged = dc.merge(right_dc, on=_get_expr(dc), right_on=_get_expr(right_dc))

cross_team = [(l_member, r_member) for l_member in team for r_member in team]
cross_team = [
(left_member, right_member) for left_member in team for right_member in team
]

count = 0
for left, right in merged.collect():
for left, right_dc in merged.collect():
assert isinstance(left, TeamMember)
assert isinstance(right, TeamMember)
l_member, r_member = cross_team[count]
assert left == l_member
assert right == r_member
assert isinstance(right_dc, TeamMember)
left_member, right_member = cross_team[count]
assert left == left_member
assert right_dc == right_member
count += 1

assert count == len(team) * len(team)

0 comments on commit 5b86816

Please sign in to comment.