diff --git a/src/datachain/lib/dc.py b/src/datachain/lib/dc.py index 11b26487a..53e400c60 100644 --- a/src/datachain/lib/dc.py +++ b/src/datachain/lib/dc.py @@ -128,14 +128,14 @@ def __init__( # noqa: D107 right_on: Optional[Sequence[Union[str, sqlalchemy.ColumnElement]]], msg: str, ): - def get_str(on: Sequence[Union[str, sqlalchemy.ColumnElement]]) -> str: + def _get_str(on: Sequence[Union[str, sqlalchemy.ColumnElement]]) -> str: if not isinstance(on, Sequence): return str(on) # type: ignore[unreachable] return ", ".join([_get_merge_error_str(col) for col in on]) - on_str = get_str(on) + on_str = _get_str(on) right_on_str = ( - ", right_on='" + get_str(right_on) + "'" + ", right_on='" + _get_str(right_on) + "'" if right_on and isinstance(right_on, Sequence) else "" ) @@ -283,8 +283,8 @@ def column(self, name: str) -> Column: raise ValueError(f"Column with name {name} not found in the schema") - def c(self, column: Union[str, Column]) -> "sqlalchemy.ColumnClause": - """Returns ColumnClause instance attached to the current chain.""" + def c(self, column: Union[str, Column]) -> Column: + """Returns Column instance attached to the current chain.""" c = self.column(column) if isinstance(column, str) else self.column(column.name) c.table = self.table return c @@ -1205,7 +1205,7 @@ def merge( errors = [] - def resolve( + def _resolve( ds: DataChain, col: Union[str, sqlalchemy.ColumnElement], side: Union[str, None], @@ -1217,8 +1217,8 @@ def resolve( errors.append(f"{_get_merge_error_str(col)} in {side}") ops = [ - resolve(self, left, "left") - == resolve(right_ds, right, "right" if right_on else None) + _resolve(self, left, "left") + == _resolve(right_ds, right, "right" if right_on else None) for left, right in zip(on, right_on or on) ] diff --git a/tests/unit/lib/test_datachain_merge.py b/tests/unit/lib/test_datachain_merge.py index 1125c8b87..5612658ac 100644 --- a/tests/unit/lib/test_datachain_merge.py +++ b/tests/unit/lib/test_datachain_merge.py @@ -269,21 +269,27 @@ def test_merge_with_itself_column(test_session): def test_merge_on_expression(test_session): - ch = DataChain.from_values(team=team, session=test_session) + def _get_expr(dc): + c = dc.c("team.sport") + return func.substr(c, func.length(c) - 3) + + dc = DataChain.from_values(team=team, session=test_session) + right_dc = dc.clone(new_table=True) + # cross join on "ball" from sport - c = ch.c("team.sport") - expr = func.substr(c, func.length(c) - 4) - merged = ch.merge(ch, on=expr) + merged = dc.merge(right_dc, on=_get_expr(dc), right_on=_get_expr(right_dc)) - cross_team = [(l_member, r_member) for l_member in team for r_member in team] + cross_team = [ + (left_member, right_member) for left_member in team for right_member in team + ] count = 0 - for left, right in merged.collect(): + for left, right_dc in merged.collect(): assert isinstance(left, TeamMember) - assert isinstance(right, TeamMember) - l_member, r_member = cross_team[count] - assert left == l_member - assert right == r_member + assert isinstance(right_dc, TeamMember) + left_member, right_member = cross_team[count] + assert left == left_member + assert right_dc == right_member count += 1 assert count == len(team) * len(team)