Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-1736729: fix diamond shape join #2871

Draft
wants to merge 13 commits into
base: main
Choose a base branch
from
96 changes: 66 additions & 30 deletions src/snowflake/snowpark/_internal/analyzer/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import uuid
from collections import Counter, defaultdict
from typing import TYPE_CHECKING, DefaultDict, Dict, List, Union
from logging import getLogger

from snowflake.connector import IntegrityError

Expand Down Expand Up @@ -153,7 +154,10 @@
)
from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages
from snowflake.snowpark._internal.telemetry import TelemetryField
from snowflake.snowpark._internal.utils import quote_name
from snowflake.snowpark._internal.utils import (
quote_name,
merge_multiple_snowflake_plan_expr_to_alias,
)
from snowflake.snowpark.types import BooleanType, _NumericType

ARRAY_BIND_THRESHOLD = 512
Expand All @@ -162,6 +166,9 @@
import snowflake.snowpark.session


_logger = getLogger(__name__)


class Analyzer:
def __init__(self, session: "snowflake.snowpark.session.Session") -> None:
self.session = session
Expand Down Expand Up @@ -368,15 +375,23 @@ def analyze(
return expr.sql

if isinstance(expr, Attribute):
name = self.alias_maps_to_use.get(expr.expr_id, expr.name)
if not self.session._resolve_conflict_alias:
name = self.alias_maps_to_use.get(expr.expr_id, expr.name)
else:
name = self.alias_maps_to_use.get(expr.expr_id, (expr.name, False))[0]
return quote_name(name)

if isinstance(expr, UnresolvedAttribute):
if expr.df_alias:
if expr.df_alias in df_aliased_col_name_to_real_col_name:
return df_aliased_col_name_to_real_col_name[expr.df_alias].get(
expr.name, expr.name
)
if not self.session._resolve_conflict_alias:
return df_aliased_col_name_to_real_col_name[expr.df_alias].get(
expr.name, expr.name
)
else:
return df_aliased_col_name_to_real_col_name[expr.df_alias].get(
expr.name, (expr.name, False)
)[0]
else:
raise SnowparkClientExceptionMessages.DF_ALIAS_NOT_RECOGNIZED(
expr.df_alias
Expand Down Expand Up @@ -406,7 +421,11 @@ def analyze(
expr.df_alias
)
columns = df_aliased_col_name_to_real_col_name[expr.df_alias]
return ",".join(columns.values())
if not self.session._resolve_conflict_alias:
ret = ",".join(columns.values())
else:
ret = ",".join([v[0] for v in columns.values()])
return ",".join(ret)
if not expr.expressions:
return "*"
else:
Expand Down Expand Up @@ -630,16 +649,28 @@ def unary_expression_extractor(
if isinstance(expr, Alias):
quoted_name = quote_name(expr.name)
if isinstance(expr.child, Attribute):
self.generated_alias_maps[expr.child.expr_id] = quoted_name
assert self.alias_maps_to_use is not None
for k, v in self.alias_maps_to_use.items():
if v == expr.child.name:
self.generated_alias_maps[k] = quoted_name

for df_alias_dict in df_aliased_col_name_to_real_col_name.values():
for k, v in df_alias_dict.items():
if not self.session._resolve_conflict_alias:
self.generated_alias_maps[expr.child.expr_id] = quoted_name
assert self.alias_maps_to_use is not None
for k, v in self.alias_maps_to_use.items():
if v == expr.child.name:
df_alias_dict[k] = quoted_name
self.generated_alias_maps[k] = quoted_name

for df_alias_dict in df_aliased_col_name_to_real_col_name.values():
for k, v in df_alias_dict.items():
if v == expr.child.name:
df_alias_dict[k] = quoted_name
else:
self.generated_alias_maps[expr.child.expr_id] = (quoted_name, False)
assert self.alias_maps_to_use is not None
for k, v in self.alias_maps_to_use.items():
if v[0] == expr.child.name:
self.generated_alias_maps[k] = (quoted_name, True)

for df_alias_dict in df_aliased_col_name_to_real_col_name.values():
for k, v in df_alias_dict.items():
if v[0] == expr.child.name:
df_alias_dict[k] = (quoted_name, True)
return alias_expression(
self.analyze(
expr.child, df_aliased_col_name_to_real_col_name, parse_local_name
Expand Down Expand Up @@ -807,22 +838,27 @@ def do_resolve(self, logical_plan: LogicalPlan) -> SnowflakePlan:
# Selectable doesn't have children. It already has the expr_to_alias dict.
self.alias_maps_to_use = logical_plan.expr_to_alias.copy()
else:
use_maps = {}
# get counts of expr_to_alias keys
counts = Counter()
for v in resolved_children.values():
if v.expr_to_alias:
counts.update(list(v.expr_to_alias.keys()))

# Keep only non-shared expr_to_alias keys
# let (df1.join(df2)).join(df2.join(df3)).select(df2) report error
for v in resolved_children.values():
if v.expr_to_alias:
use_maps.update(
{p: q for p, q in v.expr_to_alias.items() if counts[p] < 2}
)
if self.session._resolve_conflict_alias:
self.alias_maps_to_use = merge_multiple_snowflake_plan_expr_to_alias(
list(resolved_children.values())
)
else:
use_maps = {}
# get counts of expr_to_alias keys
counts = Counter()
for v in resolved_children.values():
if v.expr_to_alias:
counts.update(list(v.expr_to_alias.keys()))

# Keep only non-shared expr_to_alias keys
# let (df1.join(df2)).join(df2.join(df3)).select(df2) report error
for v in resolved_children.values():
if v.expr_to_alias:
use_maps.update(
{p: q for p, q in v.expr_to_alias.items() if counts[p] < 2}
)

self.alias_maps_to_use = use_maps
self.alias_maps_to_use = use_maps

res = self.do_resolve_with_resolved_children(
logical_plan, resolved_children, df_aliased_col_name_to_real_col_name
Expand Down
9 changes: 8 additions & 1 deletion src/snowflake/snowpark/_internal/analyzer/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,10 +151,17 @@ class NamedExpression:
name: str
_expr_id: Optional[uuid.UUID] = None

id = 0

@staticmethod
def get_next_id():
    """Advance the class-wide counter and hand back the new id."""
    next_id = NamedExpression.id + 1
    NamedExpression.id = next_id
    return next_id

@property
def expr_id(self) -> int:
    """Lazily assign and return this expression's unique id.

    The id now comes from the class-wide ``NamedExpression`` counter
    (``get_next_id``) instead of ``uuid.uuid4()``, so the previous
    ``uuid.UUID`` return annotation was stale; ids are small,
    monotonically increasing ints.
    """
    # `is None` rather than truthiness: the counter starts at 1 so the two
    # are equivalent today, but an explicit None check can't misfire if an
    # id of 0 is ever assigned externally.
    if self._expr_id is None:
        self._expr_id = NamedExpression.get_next_id()
    return self._expr_id

def __copy__(self):
Expand Down
14 changes: 11 additions & 3 deletions src/snowflake/snowpark/_internal/analyzer/select_statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -1745,9 +1745,17 @@ def derive_column_states_from_subquery(
raise SnowparkClientExceptionMessages.DF_ALIAS_NOT_RECOGNIZED(
c.child.df_alias
)
aliased_cols = from_.df_aliased_col_name_to_real_col_name[
c.child.df_alias
].values()
if not analyzer.session._resolve_conflict_alias:
aliased_cols = from_.df_aliased_col_name_to_real_col_name[
c.child.df_alias
].values()
else:
aliased_cols = [
v[0]
for v in from_.df_aliased_col_name_to_real_col_name[
c.child.df_alias
].values()
]
columns_from_star = [
copy(e)
for e in from_.column_states.projection
Expand Down
54 changes: 36 additions & 18 deletions src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@
generate_random_alphanumeric,
get_copy_into_table_options,
is_sql_select_statement,
merge_multiple_snowflake_plan_expr_to_alias,
random_name_for_temp_object,
)
from snowflake.snowpark.row import Row
Expand Down Expand Up @@ -158,13 +159,24 @@ def wrap(*args, **kwargs):
children = [
arg for arg in args if isinstance(arg, SnowflakePlan)
]
remapped = [
SnowflakePlan.Decorator.__wrap_exception_regex_sub.sub(
"", val
)
for child in children
for val in child.expr_to_alias.values()
]
remapped = []
if children:
if not children[0].session._resolve_conflict_alias:
remapped = [
SnowflakePlan.Decorator.__wrap_exception_regex_sub.sub(
"", val
)
for child in children
for val in child.expr_to_alias.values()
]
else:
remapped = [
SnowflakePlan.Decorator.__wrap_exception_regex_sub.sub(
"", val[0]
)
for child in children
for val in child.expr_to_alias.values()
]
if col in remapped:
unaliased_cols = (
snowflake.snowpark.dataframe._get_unaliased(col)
Expand Down Expand Up @@ -581,17 +593,23 @@ def build_binary(
right_schema_query = schema_value_statement(select_right.attributes)
schema_query = sql_generator(left_schema_query, right_schema_query)

common_columns = set(select_left.expr_to_alias.keys()).intersection(
select_right.expr_to_alias.keys()
)
new_expr_to_alias = {
k: v
for k, v in {
**select_left.expr_to_alias,
**select_right.expr_to_alias,
}.items()
if k not in common_columns
}
if self.session._resolve_conflict_alias:
new_expr_to_alias = merge_multiple_snowflake_plan_expr_to_alias(
[select_left, select_right]
)
else:
common_columns = set(select_left.expr_to_alias.keys()).intersection(
select_right.expr_to_alias.keys()
)
new_expr_to_alias = {
k: v
for k, v in {
**select_left.expr_to_alias,
**select_right.expr_to_alias,
}.items()
if k not in common_columns
}

api_calls = [*select_left.api_calls, *select_right.api_calls]

# Need to do a deduplication to avoid repeated query.
Expand Down
78 changes: 78 additions & 0 deletions src/snowflake/snowpark/_internal/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import sys
import threading
import traceback
import uuid
import zipfile
from enum import Enum, IntEnum, auto, unique
from functools import lru_cache
Expand Down Expand Up @@ -50,12 +51,15 @@
from snowflake.connector.description import OPERATING_SYSTEM, PLATFORM
from snowflake.connector.options import MissingOptionalDependency, ModuleLikeObject
from snowflake.connector.version import VERSION as connector_version

from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages
from snowflake.snowpark.context import _should_use_structured_type_semantics
from snowflake.snowpark.row import Row
from snowflake.snowpark.version import VERSION as snowpark_version

if TYPE_CHECKING:
from snowflake.snowpark._internal.analyzer.snowflake_plan import SnowflakePlan

try:
from snowflake.connector.cursor import ResultMetadataV2
except ImportError:
Expand Down Expand Up @@ -1510,3 +1514,77 @@ def next(self) -> int:


global_counter: GlobalCounter = GlobalCounter()


def merge_multiple_snowflake_plan_expr_to_alias(
    snowflake_plans: List["SnowflakePlan"],
) -> "Dict[uuid.UUID, Tuple[str, bool]]":
    """
    Merges expression-to-alias mappings from multiple Snowflake plans, resolving conflicts where possible.

    Each mapping value is an ``(alias_name, is_back_propagated)`` tuple. For an
    expression id that appears in several plans:

    * a single distinct value is kept as-is;
    * identical alias names with differing back-propagation flags are merged,
      OR-ing the flags;
    * genuinely conflicting names are resolved by picking the one alias that is
      still present in some plan's output and was not back-propagated;
    * keys that remain ambiguous are dropped from the result (debug-logged).

    Args:
        snowflake_plans (List[SnowflakePlan]): List of SnowflakePlan objects.

    Returns:
        The merged expression-to-alias mapping.
    """

    # Gather all expression-to-alias mappings
    all_expr_to_alias_dicts = [plan.expr_to_alias for plan in snowflake_plans]

    merged_dict = {}
    conflicted_keys = {}

    # Collect all unique keys from all dictionaries
    for key in set().union(*all_expr_to_alias_dicts):
        # Gather all distinct (alias, is_back_propagated) values for this key
        values = list({d[key] for d in all_expr_to_alias_dicts if key in d})
        if len(values) == 1:
            merged_dict[key] = values[0]
        elif len({v[0] for v in values}) == 1:
            # Alias names agree and only the back-propagation flags differ
            # (any number of occurrences, not just two): keep the name and
            # mark it back-propagated if any occurrence was.
            merged_dict[key] = (values[0][0], any(v[1] for v in values))
        else:
            conflicted_keys[key] = values

    for key, values in conflicted_keys.items():
        candidate = None
        candidate_is_back_propagated = False
        for plan in snowflake_plans:
            # Conflicted keys come from the union of all plans' mappings, so
            # a given plan may not carry this key at all; indexing
            # unconditionally here used to raise KeyError.
            if key not in plan.expr_to_alias:
                continue
            tmp_alias_name, tmp_is_back_propagated = plan.expr_to_alias[key]
            # Without a schema_query there are no usable output columns; the
            # check is hoisted out of the comprehension since it doesn't
            # depend on the attribute being iterated.
            output_columns = (
                [attr.name for attr in plan.output] if plan.schema_query else []
            )
            if tmp_alias_name not in output_columns or tmp_is_back_propagated:
                # back propagated columns are not considered as they are not
                # used in the output; check Analyzer.unary_expression_extractor
                continue
            if candidate is None:
                candidate = tmp_alias_name
                candidate_is_back_propagated = tmp_is_back_propagated
            elif candidate == tmp_alias_name:
                # The candidate is the same as the current alias
                candidate_is_back_propagated = (
                    candidate_is_back_propagated or tmp_is_back_propagated
                )
            else:
                # The candidate is different from the current alias, ambiguous.
                # A later plan may still re-establish a candidate (deliberately
                # no break, preserving prior behavior).
                candidate = None
        # Add the candidate to the merged dictionary if resolved
        if candidate is not None:
            merged_dict[key] = (candidate, candidate_is_back_propagated)
        else:
            # No valid candidate found
            # NOTE(review): `_logger` must be defined at this module's top
            # level — confirm it exists in the full file.
            _logger.debug(
                f"Expression '{key}' is associated with multiple aliases across different plans. "
                f"Unable to determine which alias to use. Conflicting values: {values}"
            )

    return merged_dict
Loading
Loading