allow merge on expressions #388

Merged (7 commits) on Sep 9, 2024
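This PR updates `DataChain.merge()` so that `on` and `right_on` accept SQL column expressions (alone, or mixed with plain signal names in a sequence) instead of only signal-name strings. As a quick orientation before the diffs, here is a minimal sketch of the new usage, adapted from the updated clip_inference.py example in this PR (the bucket path and filters are taken from that example):

```python
from datachain import C, DataChain
from datachain.sql.functions import path

source = "gs://datachain-demo/50k-laion-files/000000/00000000*"

imgs = DataChain.from_storage(source, type="image").filter(C("file.path").glob("*.jpg"))
captions = DataChain.from_storage(source, type="text").filter(C("file.path").glob("*.txt"))

# Merge keys are expressions over each chain's columns, so there is no need
# to precompute a shared "stem" signal with .map() before joining.
pairs = imgs.merge(
    captions,
    on=path.file_stem(imgs.c("file.path")),
    right_on=path.file_stem(captions.c("file.path")),
)
```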
19 changes: 10 additions & 9 deletions examples/multimodal/clip_inference.py
@@ -4,22 +4,23 @@
 from torch.utils.data import DataLoader
 
 from datachain import C, DataChain
+from datachain.sql.functions import path
 
 source = "gs://datachain-demo/50k-laion-files/000000/00000000*"
 
 
 def create_dataset():
-    imgs = (
-        DataChain.from_storage(source, type="image")
-        .filter(C("file.path").glob("*.jpg"))
-        .map(stem=lambda file: file.get_file_stem(), params=["file"], output=str)
+    imgs = DataChain.from_storage(source, type="image").filter(
+        C("file.path").glob("*.jpg")
     )
-    captions = (
-        DataChain.from_storage(source, type="text")
-        .filter(C("file.path").glob("*.txt"))
-        .map(stem=lambda file: file.get_file_stem(), params=["file"], output=str)
+    captions = DataChain.from_storage(source, type="text").filter(
+        C("file.path").glob("*.txt")
     )
-    return imgs.merge(captions, on="stem")
+    return imgs.merge(
+        captions,
+        on=path.file_stem(imgs.c("file.path")),
+        right_on=path.file_stem(captions.c("file.path")),
+    )
 
 
 if __name__ == "__main__":
21 changes: 10 additions & 11 deletions examples/multimodal/wds.py
@@ -1,6 +1,6 @@
 import os
 
-from datachain import C, DataChain
+from datachain import DataChain
 from datachain.lib.webdataset import process_webdataset
 from datachain.lib.webdataset_laion import WDSLaion, process_laion_meta
 from datachain.sql.functions import path
@@ -25,21 +25,20 @@
     DataChain.from_parquet(PARQUET_METADATA)
     .settings(cache=True)
     .merge(wds_images, on="uid", right_on="laion.json.uid", inner=True)
-    .mutate(stem=path.file_stem(C("source.file.path")))
 )
 
-res = (
+wds_npz = (
     DataChain.from_storage(NPZ_METADATA)
     .settings(cache=True)
     .gen(emd=process_laion_meta)
-    .mutate(stem=path.file_stem(C("emd.file.path")))
-    .merge(
-        wds_with_pq,
-        on=["stem", "emd.index"],
-        right_on=["stem", "source.index"],
-        inner=True,
-    )
-    .save("wds")
 )
 
+
+res = wds_npz.merge(
+    wds_with_pq,
+    on=[path.file_stem(wds_npz.c("emd.file.path")), "emd.index"],
+    right_on=[path.file_stem(wds_with_pq.c("source.file.path")), "source.index"],
+    inner=True,
+).save("wds")
 
 res.show(5)
107 changes: 74 additions & 33 deletions src/datachain/lib/dc.py
@@ -56,7 +56,7 @@
     PartitionByType,
     detach,
 )
-from datachain.query.schema import Column, DatasetRow
+from datachain.query.schema import DEFAULT_DELIMITER, Column, DatasetRow
 from datachain.sql.functions import path as pathfunc
 from datachain.utils import inside_notebook
@@ -112,11 +112,31 @@
         super().__init__(f"Dataset{name} from values error: {msg}")
 
 
+def _get_merge_error_str(col: Union[str, sqlalchemy.ColumnElement]) -> str:
+    if isinstance(col, str):
+        return col
+    if isinstance(col, sqlalchemy.Column):
+        return col.name.replace(DEFAULT_DELIMITER, ".")
+    if isinstance(col, sqlalchemy.ColumnElement) and hasattr(col, "name"):
+        return f"{col.name} expression"
+    return str(col)
+
+

[Codecov / codecov/patch: added lines src/datachain/lib/dc.py#L119 and #L122 were not covered by tests]

 class DatasetMergeError(DataChainParamsError):  # noqa: D101
-    def __init__(self, on: Sequence[str], right_on: Optional[Sequence[str]], msg: str):  # noqa: D107
-        on_str = ", ".join(on) if isinstance(on, Sequence) else ""
+    def __init__(  # noqa: D107
+        self,
+        on: Sequence[Union[str, sqlalchemy.ColumnElement]],
+        right_on: Optional[Sequence[Union[str, sqlalchemy.ColumnElement]]],
+        msg: str,
+    ):
+        def _get_str(on: Sequence[Union[str, sqlalchemy.ColumnElement]]) -> str:
+            if not isinstance(on, Sequence):
+                return str(on)  # type: ignore[unreachable]
+            return ", ".join([_get_merge_error_str(col) for col in on])
+
+        on_str = _get_str(on)
         right_on_str = (
-            ", right_on='" + ", ".join(right_on) + "'"
+            ", right_on='" + _get_str(right_on) + "'"
             if right_on and isinstance(right_on, Sequence)
             else ""
         )
@@ -252,13 +272,24 @@
         """Returns Column instance with a type if name is found in current schema,
         otherwise raises an exception.
         """
-        name_path = name.split(".")
+        if "." in name:
+            name_path = name.split(".")
+        elif DEFAULT_DELIMITER in name:
+            name_path = name.split(DEFAULT_DELIMITER)

[Review thread on lines +275 to +278]
Contributor: Just wondering, is it possible to have both a dot (.) and DEFAULT_DELIMITER (__) in a column name? 🤔
Member (author): I don't think so. It seems one is always swapped for the other.

+        else:
+            name_path = [name]
         for path, type_, _, _ in self.signals_schema.get_flat_tree():
             if path == name_path:
                 return Column(name, python_to_sql(type_))
 
         raise ValueError(f"Column with name {name} not found in the schema")
 
+    def c(self, column: Union[str, Column]) -> Column:
+        """Returns Column instance attached to the current chain."""
+        c = self.column(column) if isinstance(column, str) else self.column(column.name)
+        c.table = self.table
+        return c
+
     def print_schema(self) -> None:
         """Print schema of the chain."""
         self._effective_signals_schema.print_tree()
@@ -1140,8 +1171,17 @@
     def merge(
         self,
         right_ds: "DataChain",
-        on: Union[str, Sequence[str]],
-        right_on: Union[str, Sequence[str], None] = None,
+        on: Union[
+            str,
+            sqlalchemy.ColumnElement,
+            Sequence[Union[str, sqlalchemy.ColumnElement]],
+        ],
+        right_on: Union[
+            str,
+            sqlalchemy.ColumnElement,
+            Sequence[Union[str, sqlalchemy.ColumnElement]],
+            None,
+        ] = None,
         inner=False,
         rname="right_",
     ) -> "Self":
@@ -1166,7 +1206,7 @@
         if on is None:
             raise DatasetMergeError(["None"], None, "'on' must be specified")
 
-        if isinstance(on, str):
+        if isinstance(on, (str, sqlalchemy.ColumnElement)):
             on = [on]
         elif not isinstance(on, Sequence):
             raise DatasetMergeError(
@@ -1175,54 +1215,55 @@
                 f"'on' must be 'str' or 'Sequence' object but got type '{type(on)}'",
             )
 
-        signals_schema = self.signals_schema.clone_without_sys_signals()
-        on_columns: list[str] = signals_schema.resolve(*on).db_signals()  # type: ignore[assignment]
-
-        right_signals_schema = right_ds.signals_schema.clone_without_sys_signals()
         if right_on is not None:
-            if isinstance(right_on, str):
+            if isinstance(right_on, (str, sqlalchemy.ColumnElement)):
                 right_on = [right_on]
             elif not isinstance(right_on, Sequence):
                 raise DatasetMergeError(
                     on,
                     right_on,
                     "'right_on' must be 'str' or 'Sequence' object"
-                    f" but got type '{right_on}'",
+                    f" but got type '{type(right_on)}'",
                 )
 
             if len(right_on) != len(on):
                 raise DatasetMergeError(
                     on, right_on, "'on' and 'right_on' must have the same length'"
                 )
-
-            right_on_columns: list[str] = right_signals_schema.resolve(
-                *right_on
-            ).db_signals()  # type: ignore[assignment]
-
-            if len(right_on_columns) != len(on_columns):

[Review thread on the removed length check]
Contributor: Are we losing this check after these changes, or am I missing something? 🤔
Member (author): The check is moved down. If any of the columns fail to resolve, we will have an entry in errors and we'll raise a DatasetMergeError. (A usage sketch of this follows the diff.)

-                on_str = ", ".join(right_on_columns)
-                right_on_str = ", ".join(right_on_columns)
-                raise DatasetMergeError(
-                    on,
-                    right_on,
-                    "'on' and 'right_on' must have the same number of columns in db'."
-                    f" on -> {on_str}, right_on -> {right_on_str}",
-                )
-        else:
-            right_on = on
-            right_on_columns = on_columns
 
         if self == right_ds:
             right_ds = right_ds.clone(new_table=True)
 
+        errors = []
+
+        def _resolve(
+            ds: DataChain,
+            col: Union[str, sqlalchemy.ColumnElement],
+            side: Union[str, None],
+        ):
+            try:
+                return ds.c(col) if isinstance(col, (str, C)) else col
+            except ValueError:
+                if side:
+                    errors.append(f"{_get_merge_error_str(col)} in {side}")
+
         ops = [
-            self.c(left) == right_ds.c(right)
-            for left, right in zip(on_columns, right_on_columns)
+            _resolve(self, left, "left")
+            == _resolve(right_ds, right, "right" if right_on else None)
+            for left, right in zip(on, right_on or on)
         ]
 
+        if errors:
+            raise DatasetMergeError(
+                on, right_on, f"Could not resolve {', '.join(errors)}"
+            )
+
         ds = self.join(right_ds, sqlalchemy.and_(*ops), inner, rname + "{name}")
 
         ds.feature_schema = None
 
+        signals_schema = self.signals_schema.clone_without_sys_signals()
+        right_signals_schema = right_ds.signals_schema.clone_without_sys_signals()
         ds.signals_schema = SignalSchema({"sys": Sys}) | signals_schema.merge(
             right_signals_schema, rname
         )
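Following up on the review thread above about the moved check: resolution now happens per key inside the ops comprehension, and failures are collected and surfaced as a single DatasetMergeError. A hedged usage sketch (the chains, values, and signal names here are illustrative, not from the PR):

```python
from sqlalchemy import func

from datachain import C, DataChain
from datachain.lib.dc import DatasetMergeError

left = DataChain.from_values(key=["a", "b", "c"])
right = DataChain.from_values(key=["A", "B", "C"])

# A plain name, a C(...) column, or a SQL expression all work as merge keys.
merged = left.merge(right, on=func.lower(left.c("key")), right_on=func.lower(right.c("key")))

# A key that cannot be resolved on its side is collected and reported once.
try:
    left.merge(right, on="unknown")
except DatasetMergeError as exc:
    print(exc)  # the message names the unresolved key and the side it came from
```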
8 changes: 6 additions & 2 deletions src/datachain/query/dataset.py
@@ -1168,8 +1168,12 @@ def attached(self) -> bool:
         """
         return self.name is not None and self.version is not None
 
-    def c(self, name: Union[C, str]) -> "ColumnClause[Any]":
-        col = sqlalchemy.column(name) if isinstance(name, str) else name
+    def c(self, column: Union[C, str]) -> "ColumnClause[Any]":
+        col: sqlalchemy.ColumnClause = (
+            sqlalchemy.column(column)
+            if isinstance(column, str)
+            else sqlalchemy.column(column.name, column.type)
+        )
         col.table = self.table
         return col
 

[Review thread on the `else sqlalchemy.column(column.name, column.type)` line]
Member (author): [F] It does not seem possible to overwrite the table of a column that already has a table set, i.e. test_merge_with_itself_column fails without this change. (A standalone sketch of this follows the diff.)
Member (author): This change is no longer required for this PR to work. I'd be happy to split it out if required.
Contributor: Personally, I'm OK leaving it here. It feels natural to have this change in this PR.
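As a follow-up to the thread above, a small standalone sketch of the sqlalchemy behavior being worked around (table and column names are illustrative): a ColumnClause that is already attached to a table keeps that association, so the updated c() builds a fresh clause with the same name and type and attaches the copy instead.

```python
import sqlalchemy

users = sqlalchemy.table("users", sqlalchemy.column("id"))
other = sqlalchemy.table("users_copy")

bound = users.c.id  # already attached to "users"

# Build a fresh clause with the same name/type and attach it to the other
# table, mirroring what the updated c() in dataset.py does above.
fresh = sqlalchemy.column(bound.name, bound.type)
fresh.table = other

print(bound)  # expected: users.id
print(fresh)  # expected: users_copy.id
```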
61 changes: 57 additions & 4 deletions tests/unit/lib/test_datachain_merge.py
@@ -4,9 +4,9 @@
 import pandas as pd
 import pytest
 from pydantic import BaseModel
+from sqlalchemy import func
 
-from datachain.lib.dc import DataChain, DatasetMergeError
-from datachain.lib.signal_schema import SignalResolvingError
+from datachain.lib.dc import C, DataChain, DatasetMergeError
 from datachain.sql.types import String
 from tests.utils import skip_if_not_sqlite
@@ -196,7 +196,7 @@ def test_merge_multi_conditions(test_session):
         id=delivery_ids, d_name=delivery_name, time=delivery_time, session=test_session
     )
 
-    ch = ch1.merge(ch2, ("id", "name"), ("id", "d_name"))
+    ch = ch1.merge(ch2, ("id", "name"), ("id", C("d_name")))
 
     res = list(ch.collect())
 
@@ -213,11 +213,23 @@ def test_merge_errors(test_session):
     ch1 = DataChain.from_values(emp=employees, session=test_session)
     ch2 = DataChain.from_values(team=team, session=test_session)
 
-    with pytest.raises(SignalResolvingError):
+    with pytest.raises(DatasetMergeError):
         ch1.merge(ch2, "unknown")
 
+    with pytest.raises(DatasetMergeError):
+        ch1.merge(ch2, ["emp.person.name"], "unknown")
+
+    with pytest.raises(DatasetMergeError):
+        ch1.merge(ch2, ["emp.person.name"], ["unknown"])
+
+    with pytest.raises(DatasetMergeError):
+        ch1.merge(
+            ch2, ("emp.person.age", func.substr(["emp.person.name"], 2)), "unknown"
+        )
+
+    ch1.merge(ch2, ["emp.person.name"], ["team.sport"])
     ch1.merge(ch2, ["emp.person.name"], ["team.sport"])
 
     with pytest.raises(DatasetMergeError):
         ch1.merge(ch2, ["emp.person.name"], ["team.player", "team.sport"])
@@ -240,3 +252,44 @@ def test_merge_with_itself(test_session):
         count += 1
 
     assert count == len(employees)
+
+
+def test_merge_with_itself_column(test_session):
+    ch = DataChain.from_values(emp=employees, session=test_session)
+    merged = ch.merge(ch, C("emp.id"))
+
+    count = 0
+    for left, right in merged.collect():
+        assert isinstance(left, Employee)
+        assert isinstance(right, Employee)
+        assert left == right == employees[count]
+        count += 1
+
+    assert count == len(employees)
+
+
+def test_merge_on_expression(test_session):
+    def _get_expr(dc):
+        c = dc.c("team.sport")
+        return func.substr(c, func.length(c) - 3)
+
+    dc = DataChain.from_values(team=team, session=test_session)
+    right_dc = dc.clone(new_table=True)
+
+    # cross join on "ball" from sport
+    merged = dc.merge(right_dc, on=_get_expr(dc), right_on=_get_expr(right_dc))
+
+    cross_team = [
+        (left_member, right_member) for left_member in team for right_member in team
+    ]
+
+    count = 0
+    for left, right_dc in merged.collect():
+        assert isinstance(left, TeamMember)
+        assert isinstance(right_dc, TeamMember)
+        left_member, right_member = cross_team[count]
+        assert left == left_member
+        assert right_dc == right_member
+        count += 1
+
+    assert count == len(team) * len(team)
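Finally, a hedged, self-contained sketch of an expression-based merge outside the test fixtures; the values and the substr/length expression mirror test_merge_on_expression above, and the default session is assumed (no explicit session argument):

```python
from sqlalchemy import func

from datachain import DataChain

sports = ["basketball", "football", "baseball"]

dc = DataChain.from_values(sport=sports)
right_dc = dc.clone(new_table=True)


def trailing_four(chain):
    # last four characters of the sport name ("ball" for every value above)
    c = chain.c("sport")
    return func.substr(c, func.length(c) - 3)


# every left row pairs with every right row sharing the same suffix
merged = dc.merge(right_dc, on=trailing_four(dc), right_on=trailing_four(right_dc))
print(list(merged.collect()))
```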