-
Notifications
You must be signed in to change notification settings - Fork 104
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
allow merge on expressions #388
Changes from all commits
2a460ab
5652446
3e67154
6ad8c34
14d72b5
d4c39b2
04a425b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -56,7 +56,7 @@ | |
PartitionByType, | ||
detach, | ||
) | ||
from datachain.query.schema import Column, DatasetRow | ||
from datachain.query.schema import DEFAULT_DELIMITER, Column, DatasetRow | ||
from datachain.sql.functions import path as pathfunc | ||
from datachain.utils import inside_notebook | ||
|
||
|
@@ -112,11 +112,31 @@ | |
super().__init__(f"Dataset{name} from values error: {msg}") | ||
|
||
|
||
def _get_merge_error_str(col: Union[str, sqlalchemy.ColumnElement]) -> str: | ||
if isinstance(col, str): | ||
return col | ||
if isinstance(col, sqlalchemy.Column): | ||
return col.name.replace(DEFAULT_DELIMITER, ".") | ||
if isinstance(col, sqlalchemy.ColumnElement) and hasattr(col, "name"): | ||
return f"{col.name} expression" | ||
return str(col) | ||
|
||
|
||
class DatasetMergeError(DataChainParamsError): # noqa: D101 | ||
def __init__(self, on: Sequence[str], right_on: Optional[Sequence[str]], msg: str): # noqa: D107 | ||
on_str = ", ".join(on) if isinstance(on, Sequence) else "" | ||
def __init__( # noqa: D107 | ||
self, | ||
on: Sequence[Union[str, sqlalchemy.ColumnElement]], | ||
right_on: Optional[Sequence[Union[str, sqlalchemy.ColumnElement]]], | ||
msg: str, | ||
): | ||
def _get_str(on: Sequence[Union[str, sqlalchemy.ColumnElement]]) -> str: | ||
if not isinstance(on, Sequence): | ||
return str(on) # type: ignore[unreachable] | ||
return ", ".join([_get_merge_error_str(col) for col in on]) | ||
|
||
on_str = _get_str(on) | ||
right_on_str = ( | ||
", right_on='" + ", ".join(right_on) + "'" | ||
", right_on='" + _get_str(right_on) + "'" | ||
if right_on and isinstance(right_on, Sequence) | ||
else "" | ||
) | ||
|
@@ -252,13 +272,24 @@ | |
"""Returns Column instance with a type if name is found in current schema, | ||
otherwise raises an exception. | ||
""" | ||
name_path = name.split(".") | ||
if "." in name: | ||
name_path = name.split(".") | ||
elif DEFAULT_DELIMITER in name: | ||
name_path = name.split(DEFAULT_DELIMITER) | ||
else: | ||
name_path = [name] | ||
for path, type_, _, _ in self.signals_schema.get_flat_tree(): | ||
if path == name_path: | ||
return Column(name, python_to_sql(type_)) | ||
|
||
raise ValueError(f"Column with name {name} not found in the schema") | ||
|
||
def c(self, column: Union[str, Column]) -> Column: | ||
"""Returns Column instance attached to the current chain.""" | ||
c = self.column(column) if isinstance(column, str) else self.column(column.name) | ||
c.table = self.table | ||
return c | ||
|
||
def print_schema(self) -> None: | ||
"""Print schema of the chain.""" | ||
self._effective_signals_schema.print_tree() | ||
|
@@ -1140,8 +1171,17 @@ | |
def merge( | ||
self, | ||
right_ds: "DataChain", | ||
on: Union[str, Sequence[str]], | ||
right_on: Union[str, Sequence[str], None] = None, | ||
on: Union[ | ||
str, | ||
sqlalchemy.ColumnElement, | ||
Sequence[Union[str, sqlalchemy.ColumnElement]], | ||
], | ||
right_on: Union[ | ||
str, | ||
sqlalchemy.ColumnElement, | ||
Sequence[Union[str, sqlalchemy.ColumnElement]], | ||
None, | ||
] = None, | ||
inner=False, | ||
rname="right_", | ||
) -> "Self": | ||
|
@@ -1166,7 +1206,7 @@ | |
if on is None: | ||
raise DatasetMergeError(["None"], None, "'on' must be specified") | ||
|
||
if isinstance(on, str): | ||
if isinstance(on, (str, sqlalchemy.ColumnElement)): | ||
on = [on] | ||
elif not isinstance(on, Sequence): | ||
raise DatasetMergeError( | ||
|
@@ -1175,54 +1215,55 @@ | |
f"'on' must be 'str' or 'Sequence' object but got type '{type(on)}'", | ||
) | ||
|
||
signals_schema = self.signals_schema.clone_without_sys_signals() | ||
on_columns: list[str] = signals_schema.resolve(*on).db_signals() # type: ignore[assignment] | ||
|
||
right_signals_schema = right_ds.signals_schema.clone_without_sys_signals() | ||
if right_on is not None: | ||
if isinstance(right_on, str): | ||
if isinstance(right_on, (str, sqlalchemy.ColumnElement)): | ||
right_on = [right_on] | ||
elif not isinstance(right_on, Sequence): | ||
raise DatasetMergeError( | ||
on, | ||
right_on, | ||
"'right_on' must be 'str' or 'Sequence' object" | ||
f" but got type '{right_on}'", | ||
f" but got type '{type(right_on)}'", | ||
) | ||
|
||
if len(right_on) != len(on): | ||
raise DatasetMergeError( | ||
on, right_on, "'on' and 'right_on' must have the same length'" | ||
) | ||
|
||
right_on_columns: list[str] = right_signals_schema.resolve( | ||
*right_on | ||
).db_signals() # type: ignore[assignment] | ||
|
||
if len(right_on_columns) != len(on_columns): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Are we loosing this check after these changes or I miss something? 🤔 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The check is moved down. If any of the columns fail to resolve then we will have an entry in |
||
on_str = ", ".join(right_on_columns) | ||
right_on_str = ", ".join(right_on_columns) | ||
raise DatasetMergeError( | ||
on, | ||
right_on, | ||
"'on' and 'right_on' must have the same number of columns in db'." | ||
f" on -> {on_str}, right_on -> {right_on_str}", | ||
) | ||
else: | ||
right_on = on | ||
right_on_columns = on_columns | ||
|
||
if self == right_ds: | ||
right_ds = right_ds.clone(new_table=True) | ||
|
||
errors = [] | ||
|
||
def _resolve( | ||
ds: DataChain, | ||
col: Union[str, sqlalchemy.ColumnElement], | ||
side: Union[str, None], | ||
): | ||
try: | ||
return ds.c(col) if isinstance(col, (str, C)) else col | ||
except ValueError: | ||
if side: | ||
errors.append(f"{_get_merge_error_str(col)} in {side}") | ||
|
||
ops = [ | ||
self.c(left) == right_ds.c(right) | ||
for left, right in zip(on_columns, right_on_columns) | ||
_resolve(self, left, "left") | ||
== _resolve(right_ds, right, "right" if right_on else None) | ||
for left, right in zip(on, right_on or on) | ||
] | ||
|
||
if errors: | ||
raise DatasetMergeError( | ||
on, right_on, f"Could not resolve {', '.join(errors)}" | ||
) | ||
|
||
ds = self.join(right_ds, sqlalchemy.and_(*ops), inner, rname + "{name}") | ||
|
||
ds.feature_schema = None | ||
|
||
signals_schema = self.signals_schema.clone_without_sys_signals() | ||
right_signals_schema = right_ds.signals_schema.clone_without_sys_signals() | ||
ds.signals_schema = SignalSchema({"sys": Sys}) | signals_schema.merge( | ||
right_signals_schema, rname | ||
) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1168,8 +1168,12 @@ def attached(self) -> bool: | |
""" | ||
return self.name is not None and self.version is not None | ||
|
||
def c(self, name: Union[C, str]) -> "ColumnClause[Any]": | ||
col = sqlalchemy.column(name) if isinstance(name, str) else name | ||
def c(self, column: Union[C, str]) -> "ColumnClause[Any]": | ||
col: sqlalchemy.ColumnClause = ( | ||
sqlalchemy.column(column) | ||
if isinstance(column, str) | ||
else sqlalchemy.column(column.name, column.type) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [F] It does not seem possible to overwrite the table of a column which already has the table set. I.e There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This change is no longer required for this PR to work. I'd be happy to split it out if required. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As for me personally I am OK to leave it here. It feels natural to have this change here. |
||
) | ||
col.table = self.table | ||
return col | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just wonder, if it is possible to have both dot (
.
) andDEFAULT_DELIMITER
(__
) in column name? 🤔There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think so. Seems like one is always swapped for the other.