From 55cc36f416ce35fd3cfef033b0292aad880f95ee Mon Sep 17 00:00:00 2001 From: Anant Aneja <1797669+aaneja@users.noreply.github.com> Date: Fri, 22 Sep 2023 22:35:18 +0530 Subject: [PATCH] Revert "Enhance join reordering to work with non-simple equi-join clauses" This reverts commit 009c06bfcd7e7b54c8b15fab5eb60e2b34f79d9f. --- .../resources/sql/presto/tpcds/q02.plan.txt | 33 +-- .../resources/sql/presto/tpcds/q59.plan.txt | 4 +- .../expressions/LogicalRowExpressions.java | 12 - .../presto/SystemSessionProperties.java | 11 - .../presto/sql/analyzer/FeaturesConfig.java | 14 - .../planner/iterative/rule/ReorderJoins.java | 251 +++--------------- .../sql/analyzer/TestFeaturesConfig.java | 7 +- .../presto/sql/planner/TestDynamicFilter.java | 1 - .../assertions/RowExpressionVerifier.java | 4 +- .../iterative/rule/TestJoinEnumerator.java | 122 +-------- .../iterative/rule/TestJoinNodeFlattener.java | 74 +----- .../iterative/rule/TestReorderJoins.java | 156 +---------- .../presto/tests/AbstractTestJoinQueries.java | 24 -- .../tests/AbstractTestQueryFramework.java | 6 - 14 files changed, 77 insertions(+), 642 deletions(-) diff --git a/presto-benchto-benchmarks/src/test/resources/sql/presto/tpcds/q02.plan.txt b/presto-benchto-benchmarks/src/test/resources/sql/presto/tpcds/q02.plan.txt index 8418d929db9d..44f85122d99f 100644 --- a/presto-benchto-benchmarks/src/test/resources/sql/presto/tpcds/q02.plan.txt +++ b/presto-benchto-benchmarks/src/test/resources/sql/presto/tpcds/q02.plan.txt @@ -2,22 +2,6 @@ remote exchange (GATHER, SINGLE, []) local exchange (GATHER, UNKNOWN, []) remote exchange (REPARTITION, ROUND_ROBIN, []) join (INNER, PARTITIONED): - remote exchange (REPARTITION, HASH, [subtract_400]) - join (INNER, PARTITIONED): - final aggregation over (d_week_seq_232) - local exchange (GATHER, SINGLE, []) - remote exchange (REPARTITION, HASH, [d_week_seq_232]) - partial aggregation over (d_week_seq_232) - join (INNER, REPLICATED): - remote exchange (REPARTITION, ROUND_ROBIN, []) - scan web_sales - scan catalog_sales - local exchange (GATHER, SINGLE, []) - remote exchange (REPLICATE, BROADCAST, []) - scan date_dim - local exchange (GATHER, SINGLE, []) - remote exchange (REPARTITION, HASH, [d_week_seq_316]) - scan date_dim join (INNER, PARTITIONED): final aggregation over (d_week_seq) local exchange (GATHER, SINGLE, []) @@ -33,3 +17,20 @@ remote exchange (GATHER, SINGLE, []) local exchange (GATHER, SINGLE, []) remote exchange (REPARTITION, HASH, [d_week_seq_83]) scan date_dim + local exchange (GATHER, SINGLE, []) + remote exchange (REPARTITION, HASH, [subtract]) + join (INNER, PARTITIONED): + final aggregation over (d_week_seq_232) + local exchange (GATHER, SINGLE, []) + remote exchange (REPARTITION, HASH, [d_week_seq_232]) + partial aggregation over (d_week_seq_232) + join (INNER, REPLICATED): + remote exchange (REPARTITION, ROUND_ROBIN, []) + scan web_sales + scan catalog_sales + local exchange (GATHER, SINGLE, []) + remote exchange (REPLICATE, BROADCAST, []) + scan date_dim + local exchange (GATHER, SINGLE, []) + remote exchange (REPARTITION, HASH, [d_week_seq_316]) + scan date_dim diff --git a/presto-benchto-benchmarks/src/test/resources/sql/presto/tpcds/q59.plan.txt b/presto-benchto-benchmarks/src/test/resources/sql/presto/tpcds/q59.plan.txt index e9fc6a9c0d2e..23c58414ed69 100644 --- a/presto-benchto-benchmarks/src/test/resources/sql/presto/tpcds/q59.plan.txt +++ b/presto-benchto-benchmarks/src/test/resources/sql/presto/tpcds/q59.plan.txt @@ -1,7 +1,7 @@ local exchange (GATHER, SINGLE, []) remote exchange (GATHER, SINGLE, []) join (INNER, PARTITIONED): - remote exchange (REPARTITION, HASH, [d_week_seq, d_week_seq_267, s_store_id]) + remote exchange (REPARTITION, HASH, [d_week_seq, s_store_id]) join (INNER, REPLICATED): join (INNER, REPLICATED): final aggregation over (d_week_seq, ss_store_sk) @@ -20,7 +20,7 @@ local exchange (GATHER, SINGLE, []) remote exchange (REPLICATE, BROADCAST, []) scan store local exchange (GATHER, SINGLE, []) - remote exchange (REPARTITION, HASH, [d_week_seq_147, d_week_seq_63, s_store_id_235]) + remote exchange (REPARTITION, HASH, [s_store_id_235, subtract]) join (INNER, REPLICATED): join (INNER, REPLICATED): final aggregation over (d_week_seq_147, ss_store_sk_127) diff --git a/presto-expressions/src/main/java/com/facebook/presto/expressions/LogicalRowExpressions.java b/presto-expressions/src/main/java/com/facebook/presto/expressions/LogicalRowExpressions.java index 03cf107b1564..c00c366f3b23 100644 --- a/presto-expressions/src/main/java/com/facebook/presto/expressions/LogicalRowExpressions.java +++ b/presto-expressions/src/main/java/com/facebook/presto/expressions/LogicalRowExpressions.java @@ -26,7 +26,6 @@ import com.facebook.presto.spi.relation.SpecialFormExpression; import com.facebook.presto.spi.relation.SpecialFormExpression.Form; import com.facebook.presto.spi.relation.VariableReferenceExpression; -import com.google.common.collect.ImmutableSet; import java.util.ArrayDeque; import java.util.ArrayList; @@ -523,17 +522,6 @@ public ConvertNormalFormVisitorContext childContext() } } - private static class VariableReferenceBuilderVisitor - extends DefaultRowExpressionTraversalVisitor> - { - @Override - public Void visitVariableReference(VariableReferenceExpression variable, ImmutableSet.Builder builder) - { - builder.add(variable); - return null; - } - } - private class ConvertNormalFormVisitor implements RowExpressionVisitor { diff --git a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java index f89b53798b25..3454025a418a 100644 --- a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java +++ b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java @@ -287,7 +287,6 @@ public final class SystemSessionProperties public static final String PULL_EXPRESSION_FROM_LAMBDA_ENABLED = "pull_expression_from_lambda_enabled"; public static final String REWRITE_CONSTANT_ARRAY_CONTAINS_TO_IN_EXPRESSION = "rewrite_constant_array_contains_to_in_expression"; public static final String INFER_INEQUALITY_PREDICATES = "infer_inequality_predicates"; - public static final String HANDLE_COMPLEX_EQUI_JOINS = "handle_complex_equi_joins"; // TODO: Native execution related session properties that are temporarily put here. They will be relocated in the future. public static final String NATIVE_SIMPLIFIED_EXPRESSION_EVALUATION_ENABLED = "simplified_expression_evaluation_enabled"; @@ -1674,11 +1673,6 @@ public SystemSessionProperties( INFER_INEQUALITY_PREDICATES, "Infer nonequality predicates for joins", featuresConfig.getInferInequalityPredicates(), - false), - booleanProperty( - HANDLE_COMPLEX_EQUI_JOINS, - "Handle complex equi-join conditions to open up join space for join reordering", - featuresConfig.getHandleComplexEquiJoins(), false)); } @@ -2827,9 +2821,4 @@ public static boolean shouldInferInequalityPredicates(Session session) { return session.getSystemProperty(INFER_INEQUALITY_PREDICATES, Boolean.class); } - - public static boolean shouldHandleComplexEquiJoins(Session session) - { - return session.getSystemProperty(HANDLE_COMPLEX_EQUI_JOINS, Boolean.class); - } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java index 6e2dd6353872..20d4ed4985f6 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java @@ -277,7 +277,6 @@ public class FeaturesConfig private boolean rewriteConstantArrayContainsToIn; private boolean preProcessMetadataCalls; - private boolean handleComplexEquiJoins = true; public enum PartitioningPrecisionStrategy { @@ -2740,17 +2739,4 @@ public FeaturesConfig setRewriteConstantArrayContainsToInEnabled(boolean rewrite this.rewriteConstantArrayContainsToIn = rewriteConstantArrayContainsToIn; return this; } - - public boolean getHandleComplexEquiJoins() - { - return handleComplexEquiJoins; - } - - @Config("optimizer.handle-complex-equi-joins") - @ConfigDescription("Handle complex equi-join conditions to open up join space for join reordering") - public FeaturesConfig setHandleComplexEquiJoins(boolean handleComplexEquiJoins) - { - this.handleComplexEquiJoins = handleComplexEquiJoins; - return this; - } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ReorderJoins.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ReorderJoins.java index aea88ca5b81a..fa86d5abf501 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ReorderJoins.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ReorderJoins.java @@ -24,12 +24,10 @@ import com.facebook.presto.matching.Pattern; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.VariableAllocator; -import com.facebook.presto.spi.plan.Assignments; import com.facebook.presto.spi.plan.FilterNode; import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.PlanNodeId; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; -import com.facebook.presto.spi.plan.ProjectNode; import com.facebook.presto.spi.relation.CallExpression; import com.facebook.presto.spi.relation.DeterminismEvaluator; import com.facebook.presto.spi.relation.RowExpression; @@ -37,6 +35,7 @@ import com.facebook.presto.sql.analyzer.FeaturesConfig.JoinDistributionType; import com.facebook.presto.sql.planner.CanonicalJoinNode; import com.facebook.presto.sql.planner.EqualityInference; +import com.facebook.presto.sql.planner.VariablesExtractor; import com.facebook.presto.sql.planner.iterative.Lookup; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.JoinNode; @@ -54,7 +53,6 @@ import java.util.ArrayList; import java.util.HashMap; -import java.util.HashSet; import java.util.LinkedHashSet; import java.util.List; import java.util.Map; @@ -68,15 +66,12 @@ import static com.facebook.presto.SystemSessionProperties.getJoinDistributionType; import static com.facebook.presto.SystemSessionProperties.getJoinReorderingStrategy; import static com.facebook.presto.SystemSessionProperties.getMaxReorderedJoins; -import static com.facebook.presto.SystemSessionProperties.shouldHandleComplexEquiJoins; +import static com.facebook.presto.common.function.OperatorType.EQUAL; import static com.facebook.presto.expressions.LogicalRowExpressions.TRUE_CONSTANT; import static com.facebook.presto.expressions.LogicalRowExpressions.and; import static com.facebook.presto.expressions.LogicalRowExpressions.extractConjuncts; -import static com.facebook.presto.expressions.RowExpressionNodeInliner.replaceExpression; import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinReorderingStrategy.AUTOMATIC; import static com.facebook.presto.sql.planner.EqualityInference.createEqualityInference; -import static com.facebook.presto.sql.planner.PlannerUtils.addProjections; -import static com.facebook.presto.sql.planner.VariablesExtractor.extractUnique; import static com.facebook.presto.sql.planner.iterative.rule.DetermineJoinDistributionType.isBelowMaxBroadcastSize; import static com.facebook.presto.sql.planner.iterative.rule.ReorderJoins.JoinEnumerationResult.INFINITE_COST_RESULT; import static com.facebook.presto.sql.planner.iterative.rule.ReorderJoins.JoinEnumerationResult.UNKNOWN_COST_RESULT; @@ -140,8 +135,7 @@ public boolean isEnabled(Session session) @Override public Result apply(JoinNode joinNode, Captures captures, Context context) { - MultiJoinNode multiJoinNode = toMultiJoinNode(joinNode, context.getLookup(), getMaxReorderedJoins(context.getSession()), shouldHandleComplexEquiJoins(context.getSession()), - functionResolution, determinismEvaluator); + MultiJoinNode multiJoinNode = toMultiJoinNode(joinNode, context.getLookup(), getMaxReorderedJoins(context.getSession()), functionResolution, determinismEvaluator); JoinEnumerator joinEnumerator = new JoinEnumerator( costComparator, multiJoinNode.getFilter(), @@ -149,19 +143,11 @@ public Result apply(JoinNode joinNode, Captures captures, Context context) determinismEvaluator, functionResolution, metadata); - JoinEnumerationResult result = joinEnumerator.chooseJoinOrder(multiJoinNode.getSources(), multiJoinNode.getOutputVariables()); - if (!result.getPlanNode().isPresent()) { return Result.empty(); } - - PlanNode transformedPlan = result.getPlanNode().get(); - if (multiJoinNode.getAssignments().isPresent()) { - transformedPlan = addProjections(transformedPlan, context.getIdAllocator(), multiJoinNode.getAssignments().get().getMap()); - } - - return Result.ofPlanNode(transformedPlan); + return Result.ofPlanNode(result.getPlanNode().get()); } @VisibleForTesting @@ -180,7 +166,6 @@ static class JoinEnumerator private final Context context; private final Map, JoinEnumerationResult> memo = new HashMap<>(); - private final FunctionResolution functionResolution; @VisibleForTesting JoinEnumerator(CostComparator costComparator, RowExpression filter, Context context, DeterminismEvaluator determinismEvaluator, FunctionResolution functionResolution, Metadata metadata) @@ -196,7 +181,6 @@ static class JoinEnumerator this.metadata = requireNonNull(metadata, "metadata is null"); this.allFilterInference = createEqualityInference(metadata, filter); this.logicalRowExpressions = new LogicalRowExpressions(determinismEvaluator, functionResolution, metadata.getFunctionAndTypeManager()); - this.functionResolution = functionResolution; } private JoinEnumerationResult chooseJoinOrder(LinkedHashSet sources, List outputVariables) @@ -271,31 +255,28 @@ JoinEnumerationResult createJoinAccordingToPartitioning(LinkedHashSet private JoinEnumerationResult createJoin(LinkedHashSet leftSources, LinkedHashSet rightSources, List outputVariables) { - HashSet leftVariables = leftSources.stream() + Set leftVariables = leftSources.stream() .flatMap(node -> node.getOutputVariables().stream()) - .collect(toCollection(HashSet::new)); - HashSet rightVariables = rightSources.stream() + .collect(toImmutableSet()); + Set rightVariables = rightSources.stream() .flatMap(node -> node.getOutputVariables().stream()) - .collect(toCollection(HashSet::new)); + .collect(toImmutableSet()); List joinPredicates = getJoinPredicates(leftVariables, rightVariables); - - VariableAllocator variableAllocator = context.getVariableAllocator(); - JoinCondition joinConditions = extractJoinConditions(joinPredicates, leftVariables, rightVariables, variableAllocator); - List joinClauses = joinConditions.getJoinClauses(); - List joinFilters = joinConditions.getJoinFilters(); - - //Update the left & right variable sets with any new variables generated - leftVariables.addAll(joinConditions.getNewLeftAssignments().keySet()); - rightVariables.addAll(joinConditions.getNewRightAssignments().keySet()); - - if (joinClauses.isEmpty()) { + List joinConditions = joinPredicates.stream() + .filter(JoinEnumerator::isJoinEqualityCondition) + .map(predicate -> toEquiJoinClause((CallExpression) predicate, leftVariables, context.getVariableAllocator())) + .collect(toImmutableList()); + if (joinConditions.isEmpty()) { return INFINITE_COST_RESULT; } + List joinFilters = joinPredicates.stream() + .filter(predicate -> !isJoinEqualityCondition(predicate)) + .collect(toImmutableList()); Set requiredJoinVariables = ImmutableSet.builder() .addAll(outputVariables) - .addAll(extractUnique(joinPredicates)) + .addAll(VariablesExtractor.extractUnique(joinPredicates)) .build(); JoinEnumerationResult leftResult = getJoinSource( @@ -311,13 +292,6 @@ private JoinEnumerationResult createJoin(LinkedHashSet leftSources, Li } PlanNode left = leftResult.planNode.orElseThrow(() -> new VerifyException("Plan node is not present")); - if (!joinConditions.getNewLeftAssignments().isEmpty()) { - ImmutableMap.Builder assignments = ImmutableMap.builder(); - left.getOutputVariables().forEach(outputVariable -> assignments.put(outputVariable, outputVariable)); - assignments.putAll(joinConditions.getNewLeftAssignments()); - - left = addProjections(left, idAllocator, assignments.build()); - } JoinEnumerationResult rightResult = getJoinSource( rightSources, @@ -332,13 +306,6 @@ private JoinEnumerationResult createJoin(LinkedHashSet leftSources, Li } PlanNode right = rightResult.planNode.orElseThrow(() -> new VerifyException("Plan node is not present")); - if (!joinConditions.getNewRightAssignments().isEmpty()) { - ImmutableMap.Builder assignments = ImmutableMap.builder(); - right.getOutputVariables().forEach(outputVariable -> assignments.put(outputVariable, outputVariable)); - assignments.putAll(joinConditions.getNewRightAssignments()); - - right = addProjections(right, idAllocator, assignments.build()); - } // sort output variables so that the left input variables are first List sortedOutputVariables = Stream.concat(left.getOutputVariables().stream(), right.getOutputVariables().stream()) @@ -351,7 +318,7 @@ private JoinEnumerationResult createJoin(LinkedHashSet leftSources, Li INNER, left, right, - joinClauses, + joinConditions, sortedOutputVariables, joinFilters.isEmpty() ? Optional.empty() : Optional.of(and(joinFilters)), Optional.empty(), @@ -407,103 +374,22 @@ private JoinEnumerationResult getJoinSource(LinkedHashSet nodes, List< return chooseJoinOrder(nodes, outputVariables); } - @VisibleForTesting - JoinCondition extractJoinConditions(List joinPredicates, - Set leftVariables, - Set rightVariables, - VariableAllocator variableAllocator) + private static boolean isJoinEqualityCondition(RowExpression expression) { - ImmutableMap.Builder newLeftAssignments = ImmutableMap.builder(); - ImmutableMap.Builder newRightAssignments = ImmutableMap.builder(); - - ImmutableList.Builder joinClauses = ImmutableList.builder(); - ImmutableList.Builder joinFilters = ImmutableList.builder(); - - for (RowExpression predicate : joinPredicates) { - if (predicate instanceof CallExpression - && functionResolution.isEqualFunction(((CallExpression) predicate).getFunctionHandle()) - && ((CallExpression) predicate).getArguments().size() == 2) { - RowExpression argument0 = ((CallExpression) predicate).getArguments().get(0); - RowExpression argument1 = ((CallExpression) predicate).getArguments().get(1); - - // First check if arguments refer to different sides of join - Set argument0Vars = extractUnique(argument0); - Set argument1Vars = extractUnique(argument1); - if (!((leftVariables.containsAll(argument0Vars) && rightVariables.containsAll(argument1Vars)) - || (rightVariables.containsAll(argument0Vars) && leftVariables.containsAll(argument1Vars)))) { - // Arguments have a mix of join sides, use this predicate as a filter - joinFilters.add(predicate); - continue; - } - - // Next, check to see if first argument refers to left side and second argument to the right side - // If not, swap the arguments - if (leftVariables.containsAll(argument1Vars)) { - RowExpression temp = argument1; - argument1 = argument0; - argument0 = temp; - } - - // Next, check if we need to create new assignments for complex equi-join clauses - // E.g. leftVar = ADD(rightVar1, rightVar2) - if (!(argument0 instanceof VariableReferenceExpression)) { - VariableReferenceExpression newLeft = variableAllocator.newVariable(argument0); - newLeftAssignments.put(newLeft, argument0); - argument0 = newLeft; - } - - if (!(argument1 instanceof VariableReferenceExpression)) { - VariableReferenceExpression newRight = variableAllocator.newVariable(argument1); - newRightAssignments.put(newRight, argument1); - argument1 = newRight; - } - - joinClauses.add(new EquiJoinClause((VariableReferenceExpression) argument0, (VariableReferenceExpression) argument1)); - } - else { - joinFilters.add(predicate); - } - } - - return new JoinCondition(joinClauses.build(), joinFilters.build(), newLeftAssignments.build(), newRightAssignments.build()); + return expression instanceof CallExpression + && ((CallExpression) expression).getDisplayName().equals(EQUAL.getFunctionName().getObjectName()) + && ((CallExpression) expression).getArguments().size() == 2 + && ((CallExpression) expression).getArguments().get(0) instanceof VariableReferenceExpression + && ((CallExpression) expression).getArguments().get(1) instanceof VariableReferenceExpression; } - @VisibleForTesting - static class JoinCondition + private static EquiJoinClause toEquiJoinClause(CallExpression equality, Set leftVariables, VariableAllocator variableAllocator) { - List joinClauses; - List joinFilters; - Map newLeftAssignments; - Map newRightAssignments; - - public JoinCondition(List joinClauses, List joinFilters, - Map left, Map right) - { - this.joinClauses = joinClauses; - this.joinFilters = joinFilters; - this.newLeftAssignments = left; - this.newRightAssignments = right; - } - - public List getJoinClauses() - { - return joinClauses; - } - - public List getJoinFilters() - { - return joinFilters; - } - - public Map getNewLeftAssignments() - { - return newLeftAssignments; - } - - public Map getNewRightAssignments() - { - return newRightAssignments; - } + checkArgument(equality.getArguments().size() == 2, "Unexpected number of arguments in binary operator equals"); + VariableReferenceExpression leftVariable = (VariableReferenceExpression) equality.getArguments().get(0); + VariableReferenceExpression rightVariable = (VariableReferenceExpression) equality.getArguments().get(1); + EquiJoinClause equiJoinClause = new EquiJoinClause(leftVariable, rightVariable); + return leftVariables.contains(leftVariable) ? equiJoinClause : equiJoinClause.flip(); } private JoinEnumerationResult setJoinNodeProperties(JoinNode joinNode) @@ -573,10 +459,8 @@ static class MultiJoinNode { // Use a linked hash set to ensure optimizer is deterministic private final CanonicalJoinNode node; - private final Optional assignments; - public MultiJoinNode(LinkedHashSet sources, RowExpression filter, List outputVariables, - Assignments assignments) + public MultiJoinNode(LinkedHashSet sources, RowExpression filter, List outputVariables) { checkArgument(sources.size() > 1, "sources size is <= 1"); @@ -592,23 +476,8 @@ public MultiJoinNode(LinkedHashSet sources, RowExpression filter, List ImmutableSet.of(filter), outputVariables); - ImmutableSet inputVariables = sources.stream().flatMap(source -> source.getOutputVariables().stream()).collect(toImmutableSet()); - // We could have some output variables that were possibly generated from intermediate projects that were removed - // We will need to create a wrapper Project to add them back - Assignments.Builder builder = Assignments.builder(); - boolean nonIdentityAssignmentsFound = false; - for (VariableReferenceExpression outputVariable : outputVariables) { - if (inputVariables.contains(outputVariable)) { - builder.put(outputVariable, outputVariable); - continue; - } - checkState(assignments.getMap().containsKey(outputVariable), - "Output variable [%s] not found in input variables or intermediate assignments", outputVariable); - nonIdentityAssignmentsFound = true; - builder.put(outputVariable, assignments.get(outputVariable)); - } - - this.assignments = nonIdentityAssignmentsFound ? Optional.of(builder.build()) : Optional.empty(); + List inputVariables = sources.stream().flatMap(source -> source.getOutputVariables().stream()).collect(toImmutableList()); + checkArgument(inputVariables.containsAll(outputVariables), "inputs do not contain all output variables"); } public RowExpression getFilter() @@ -626,11 +495,6 @@ public List getOutputVariables() return node.getOutputVariables(); } - public Optional getAssignments() - { - return assignments; - } - public static Builder builder() { return new Builder(); @@ -655,25 +519,22 @@ public boolean equals(Object obj) && getOutputVariables().equals(other.getOutputVariables()); } - static MultiJoinNode toMultiJoinNode(JoinNode joinNode, Lookup lookup, int joinLimit, boolean handleComplexEquiJoins, FunctionResolution functionResolution, DeterminismEvaluator determinismEvaluator) + static MultiJoinNode toMultiJoinNode(JoinNode joinNode, Lookup lookup, int joinLimit, FunctionResolution functionResolution, DeterminismEvaluator determinismEvaluator) { // the number of sources is the number of joins + 1 - return new JoinNodeFlattener(joinNode, lookup, joinLimit + 1, handleComplexEquiJoins, functionResolution, determinismEvaluator).toMultiJoinNode(); + return new JoinNodeFlattener(joinNode, lookup, joinLimit + 1, functionResolution, determinismEvaluator).toMultiJoinNode(); } private static class JoinNodeFlattener { private final LinkedHashSet sources = new LinkedHashSet<>(); - private final Assignments assignments; - private final boolean handleComplexEquiJoins; - private List filters = new ArrayList<>(); + private final List filters = new ArrayList<>(); private final List outputVariables; private final FunctionResolution functionResolution; private final DeterminismEvaluator determinismEvaluator; private final Lookup lookup; - JoinNodeFlattener(JoinNode node, Lookup lookup, int sourceLimit, boolean handleComplexEquiJoins, FunctionResolution functionResolution, - DeterminismEvaluator determinismEvaluator) + JoinNodeFlattener(JoinNode node, Lookup lookup, int sourceLimit, FunctionResolution functionResolution, DeterminismEvaluator determinismEvaluator) { requireNonNull(node, "node is null"); checkState(node.getType() == INNER, "join type must be INNER"); @@ -681,39 +542,13 @@ private static class JoinNodeFlattener this.lookup = requireNonNull(lookup, "lookup is null"); this.functionResolution = requireNonNull(functionResolution, "functionResolution is null"); this.determinismEvaluator = requireNonNull(determinismEvaluator, "determinismEvaluator is null"); - this.handleComplexEquiJoins = handleComplexEquiJoins; - Assignments.Builder intermediateAssignments = Assignments.builder(); - flattenNode(node, sourceLimit, intermediateAssignments); - this.assignments = intermediateAssignments.build(); - rewriteFilterWithInlinedAssignments(intermediateAssignments.build()); - } - - private void rewriteFilterWithInlinedAssignments(Assignments assignments) - { - ImmutableList.Builder modifiedFilters = ImmutableList.builder(); - filters.forEach(filter -> modifiedFilters.add(replaceExpression(filter, assignments.getMap()))); - filters = modifiedFilters.build(); + flattenNode(node, sourceLimit); } - private void flattenNode(PlanNode node, int limit, Assignments.Builder assignmentsBuilder) + private void flattenNode(PlanNode node, int limit) { PlanNode resolved = lookup.resolve(node); - if (resolved instanceof ProjectNode) { - ProjectNode projectNode = (ProjectNode) resolved; - // A ProjectNode could be 'hiding' a join source by building an assignment of a complex equi-join criteria like `left.key = right1.key1 + right1.key2` - // We open up the join space by tracking the assignments from this Project node; these will be inlined into the overall filters once we finish - // traversing the join graph - if (handleComplexEquiJoins && lookup.resolve(projectNode.getSource()) instanceof JoinNode) { - assignmentsBuilder.putAll(projectNode.getAssignments()); - flattenNode(projectNode.getSource(), limit, assignmentsBuilder); - } - else { - sources.add(node); - } - return; - } - // (limit - 2) because you need to account for adding left and right side if (!(resolved instanceof JoinNode) || (sources.size() > (limit - 2))) { sources.add(node); @@ -727,8 +562,8 @@ private void flattenNode(PlanNode node, int limit, Assignments.Builder assignmen } // we set the left limit to limit - 1 to account for the node on the right - flattenNode(joinNode.getLeft(), limit - 1, assignmentsBuilder); - flattenNode(joinNode.getRight(), limit, assignmentsBuilder); + flattenNode(joinNode.getLeft(), limit - 1); + flattenNode(joinNode.getRight(), limit); joinNode.getCriteria().stream() .map(criteria -> toRowExpression(criteria, functionResolution)) .forEach(filters::add); @@ -737,7 +572,7 @@ private void flattenNode(PlanNode node, int limit, Assignments.Builder assignmen MultiJoinNode toMultiJoinNode() { - return new MultiJoinNode(sources, and(filters), outputVariables, assignments); + return new MultiJoinNode(sources, and(filters), outputVariables); } } @@ -767,7 +602,7 @@ public Builder setOutputVariables(VariableReferenceExpression... outputVariables public MultiJoinNode build() { - return new MultiJoinNode(new LinkedHashSet<>(sources), filter, outputVariables, Assignments.builder().build()); + return new MultiJoinNode(new LinkedHashSet<>(sources), filter, outputVariables); } } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java index ef82b90f748b..a7959165dc0a 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java @@ -242,8 +242,7 @@ public void testDefaults() .setAddPartialNodeForRowNumberWithLimitEnabled(true) .setInferInequalityPredicates(false) .setPullUpExpressionFromLambdaEnabled(false) - .setRewriteConstantArrayContainsToInEnabled(false) - .setHandleComplexEquiJoins(true)); + .setRewriteConstantArrayContainsToInEnabled(false)); } @Test @@ -434,7 +433,6 @@ public void testExplicitPropertyMappings() .put("optimizer.infer-inequality-predicates", "true") .put("optimizer.pull-up-expression-from-lambda", "true") .put("optimizer.rewrite-constant-array-contains-to-in", "true") - .put("optimizer.handle-complex-equi-joins", "false") .build(); FeaturesConfig expected = new FeaturesConfig() @@ -622,8 +620,7 @@ public void testExplicitPropertyMappings() .setAddPartialNodeForRowNumberWithLimitEnabled(false) .setInferInequalityPredicates(true) .setPullUpExpressionFromLambdaEnabled(true) - .setRewriteConstantArrayContainsToInEnabled(true) - .setHandleComplexEquiJoins(false); + .setRewriteConstantArrayContainsToInEnabled(true); assertFullMapping(properties, expected); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestDynamicFilter.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestDynamicFilter.java index 4e386ab85427..63dea91103d9 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestDynamicFilter.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestDynamicFilter.java @@ -439,7 +439,6 @@ public void testNonPushedDownJoinFilterRemoval() "SELECT 1 FROM part t0, part t1, part t2 " + "WHERE t0.partkey = t1.partkey AND t0.partkey = t2.partkey " + "AND t0.size + t1.size = t2.size", - noJoinReordering(), anyTree( join( INNER, diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/RowExpressionVerifier.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/RowExpressionVerifier.java index 062cca4bbc5c..9d0f7175ad91 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/RowExpressionVerifier.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/RowExpressionVerifier.java @@ -100,7 +100,7 @@ /** * RowExpression visitor which verifies if given expression (actual) is matching other RowExpression given as context (expected). */ -public final class RowExpressionVerifier +final class RowExpressionVerifier extends AstVisitor { // either use variable or input reference for symbol mapping @@ -110,7 +110,7 @@ public final class RowExpressionVerifier private final FunctionResolution functionResolution; private final Set lambdaArguments; - public RowExpressionVerifier(SymbolAliases symbolAliases, Metadata metadata, Session session) + RowExpressionVerifier(SymbolAliases symbolAliases, Metadata metadata, Session session) { this.symbolAliases = requireNonNull(symbolAliases, "symbolLayout is null"); this.metadata = requireNonNull(metadata, "metadata is null"); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestJoinEnumerator.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestJoinEnumerator.java index 37e2f15062be..d60ea7b95e8a 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestJoinEnumerator.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestJoinEnumerator.java @@ -15,69 +15,46 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.Session; -import com.facebook.presto.common.type.Type; import com.facebook.presto.cost.CachingCostProvider; import com.facebook.presto.cost.CachingStatsProvider; import com.facebook.presto.cost.CostComparator; import com.facebook.presto.cost.CostProvider; import com.facebook.presto.cost.PlanCostEstimate; import com.facebook.presto.cost.StatsProvider; -import com.facebook.presto.expressions.LogicalRowExpressions; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.VariableAllocator; import com.facebook.presto.spi.WarningCollector; -import com.facebook.presto.spi.plan.Assignments; import com.facebook.presto.spi.plan.LogicalPropertiesProvider; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; import com.facebook.presto.spi.relation.DeterminismEvaluator; -import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; -import com.facebook.presto.sql.TestingRowExpressionTranslator; import com.facebook.presto.sql.planner.TypeProvider; -import com.facebook.presto.sql.planner.assertions.RowExpressionVerifier; -import com.facebook.presto.sql.planner.assertions.SymbolAliases; import com.facebook.presto.sql.planner.iterative.Lookup; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.iterative.rule.ReorderJoins.JoinEnumerationResult; import com.facebook.presto.sql.planner.iterative.rule.ReorderJoins.JoinEnumerator; -import com.facebook.presto.sql.planner.iterative.rule.ReorderJoins.JoinEnumerator.JoinCondition; import com.facebook.presto.sql.planner.iterative.rule.ReorderJoins.MultiJoinNode; import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; import com.facebook.presto.sql.relational.FunctionResolution; import com.facebook.presto.sql.relational.RowExpressionDeterminismEvaluator; -import com.facebook.presto.sql.tree.SymbolReference; import com.facebook.presto.testing.LocalQueryRunner; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import org.testng.annotations.AfterClass; import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; -import java.util.Arrays; import java.util.LinkedHashSet; -import java.util.List; -import java.util.Map; import java.util.Optional; -import java.util.Set; -import java.util.stream.Collectors; import static com.facebook.airlift.testing.Closeables.closeAllRuntimeException; import static com.facebook.presto.SessionTestUtils.TEST_SESSION; -import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.expressions.LogicalRowExpressions.TRUE_CONSTANT; -import static com.facebook.presto.expressions.RowExpressionNodeInliner.replaceExpression; import static com.facebook.presto.sql.planner.iterative.Lookup.noLookup; import static com.facebook.presto.sql.planner.iterative.rule.ReorderJoins.JoinEnumerator.generatePartitions; -import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.expression; -import static com.facebook.presto.sql.planner.optimizations.JoinNodeUtils.toRowExpression; -import static com.facebook.presto.sql.relational.Expressions.variable; import static com.facebook.presto.testing.TestingSession.testSessionBuilder; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; -import static org.testng.Assert.assertNotNull; -import static org.testng.Assert.assertNull; -import static org.testng.Assert.assertTrue; public class TestJoinEnumerator { @@ -85,20 +62,14 @@ public class TestJoinEnumerator private Metadata metadata; private DeterminismEvaluator determinismEvaluator; private FunctionResolution functionResolution; - private PlanBuilder planBuilder; - private TestingRowExpressionTranslator rowExpressionTranslator; - private Session session; @BeforeClass public void setUp() { - session = testSessionBuilder().build(); - queryRunner = new LocalQueryRunner(session); + queryRunner = new LocalQueryRunner(testSessionBuilder().build()); metadata = queryRunner.getMetadata(); determinismEvaluator = new RowExpressionDeterminismEvaluator(metadata); functionResolution = new FunctionResolution(metadata.getFunctionAndTypeManager().getFunctionAndTypeResolver()); - planBuilder = new PlanBuilder(session, new PlanNodeIdAllocator(), metadata); - rowExpressionTranslator = new TestingRowExpressionTranslator(metadata); } @AfterClass(alwaysRun = true) @@ -138,8 +109,7 @@ public void testDoesNotCreateJoinWhenPartitionedOnCrossJoin() MultiJoinNode multiJoinNode = new MultiJoinNode( new LinkedHashSet<>(ImmutableList.of(p.values(a1), p.values(b1))), TRUE_CONSTANT, - ImmutableList.of(a1, b1), - Assignments.builder().build()); + ImmutableList.of(a1, b1)); JoinEnumerator joinEnumerator = new JoinEnumerator( new CostComparator(1, 1, 1), multiJoinNode.getFilter(), @@ -152,94 +122,6 @@ public void testDoesNotCreateJoinWhenPartitionedOnCrossJoin() assertEquals(actual.getCost(), PlanCostEstimate.infinite()); } - @Test - public void testJoinClauseAndFilterInference() - { - ImmutableMap.Builder builder = ImmutableMap.builder(); - builder.put("a", BIGINT); - builder.put("b", BIGINT); - builder.put("c", BIGINT); - builder.put("d", BIGINT); - Map variableMap = builder.build(); - VariableReferenceExpression a = variable("a", variableMap.get("a")); - VariableReferenceExpression b = variable("b", variableMap.get("b")); - VariableReferenceExpression c = variable("c", variableMap.get("c")); - VariableReferenceExpression d = variable("d", variableMap.get("d")); - - SymbolAliases.Builder newAliases = SymbolAliases.builder(); - newAliases.put("A", new SymbolReference("a")); - newAliases.put("B", new SymbolReference("b")); - newAliases.put("C", new SymbolReference("c")); - newAliases.put("D", new SymbolReference("d")); - SymbolAliases symbolAliases = newAliases.build(); - - // Simple join predicates on variable references - assertJoinCondition(symbolAliases, toRowExpressionList(variableMap, "a = b"), ImmutableSet.of(a), ImmutableSet.of(b, c), "A = B", null); - assertJoinCondition(symbolAliases, toRowExpressionList(variableMap, "a = b", "c = d"), ImmutableSet.of(a, c), ImmutableSet.of(b, d), "A = B AND C = D", null); - // Complex join predicate - All variables must be from one join side to have the predicate be an equi-join clause - assertJoinCondition(symbolAliases, toRowExpressionList(variableMap, "a = b + c"), ImmutableSet.of(a), ImmutableSet.of(b, c), "A = B + C", null); - // Left and right side designation can be switched - assertJoinCondition(symbolAliases, toRowExpressionList(variableMap, "a = b + c"), ImmutableSet.of(b, c), ImmutableSet.of(a), "A = B + C", null); - assertJoinCondition(symbolAliases, toRowExpressionList(variableMap, "a = b + c + 1"), ImmutableSet.of(a), ImmutableSet.of(b, c), "A = B + C + 1", null); - assertJoinCondition(symbolAliases, toRowExpressionList(variableMap, "a = b + c + 1"), ImmutableSet.of(b, c), ImmutableSet.of(a), "A = B + C + 1", null); - // If a predicate has a mix of variables from left & right sides, the predicate is treated as a filter - assertJoinCondition(symbolAliases, toRowExpressionList(variableMap, "a + b = c"), ImmutableSet.of(a), ImmutableSet.of(b, c), null, "A + B = C"); - assertJoinCondition(symbolAliases, toRowExpressionList(variableMap, "a + b = 1"), ImmutableSet.of(a), ImmutableSet.of(b), null, "A + B = 1"); - // Test with multiple equi-join conditions and filters - assertJoinCondition(symbolAliases, toRowExpressionList(variableMap, "a = ABS(b)", "a = ceil(b-c)", "b = c + 10"), - ImmutableSet.of(a), ImmutableSet.of(b, c), "A = abs(B) AND A = ceil(B-C)", "B = C + 10"); - } - - private List toRowExpressionList(Map variableTypeMap, String... predicates) - { - return Arrays.stream(predicates) - .map(p -> rowExpressionTranslator.translate(p, variableTypeMap)) - .collect(Collectors.toList()); - } - - private void assertJoinCondition(SymbolAliases symbolAliases, List joinPredicates, Set leftVariables, - Set rightVariables, String expectedEquiJoinClause, String expectedJoinFilter) - { - RowExpressionVerifier verifier = new RowExpressionVerifier(symbolAliases, metadata, session); - JoinEnumerator joinEnumerator = new JoinEnumerator( - new CostComparator(1, 1, 1), - TRUE_CONSTANT, - createContext(), - determinismEvaluator, - functionResolution, - metadata); - - JoinCondition joinConditions = joinEnumerator.extractJoinConditions(joinPredicates, - leftVariables, rightVariables, new VariableAllocator()); - - Optional equiJoinExpressionInlined = joinConditions.getJoinClauses().stream() - .map(criteria -> toRowExpression(criteria, functionResolution)) - // We may have made left or right assignments to build the equi join clause - // We inline these assignments for building the equi join clause to verify - .map(expression -> replaceExpression(expression, joinConditions.getNewLeftAssignments())) - .map(expression -> replaceExpression(expression, joinConditions.getNewRightAssignments())) - .reduce(LogicalRowExpressions::and); - - if (equiJoinExpressionInlined.isPresent()) { - assertNotNull(expectedEquiJoinClause); - assertTrue(verifier.process(expression(expectedEquiJoinClause), equiJoinExpressionInlined.get())); - } - else { - assertNull(expectedEquiJoinClause); - } - - Optional joinFilter = joinConditions.getJoinFilters().stream() - .reduce(LogicalRowExpressions::and); - - if (joinFilter.isPresent()) { - assertNotNull(expectedJoinFilter); - assertTrue(verifier.process(expression(expectedJoinFilter), joinFilter.get())); - } - else { - assertNull(expectedJoinFilter); - } - } - private Rule.Context createContext() { PlanNodeIdAllocator planNodeIdAllocator = new PlanNodeIdAllocator(); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestJoinNodeFlattener.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestJoinNodeFlattener.java index dd057fe41823..6ba4bc4f7c9b 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestJoinNodeFlattener.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestJoinNodeFlattener.java @@ -15,9 +15,7 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.common.function.OperatorType; -import com.facebook.presto.spi.plan.Assignments; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; -import com.facebook.presto.spi.plan.ProjectNode; import com.facebook.presto.spi.plan.ValuesNode; import com.facebook.presto.spi.relation.DeterminismEvaluator; import com.facebook.presto.spi.relation.RowExpression; @@ -88,7 +86,7 @@ public void testDoesNotAllowOuterJoin() ImmutableList.of(equiJoinClause(a1, b1)), ImmutableList.of(a1, b1), Optional.empty()); - toMultiJoinNode(outerJoin, noLookup(), DEFAULT_JOIN_LIMIT, true, functionResolution, determinismEvaluator); + toMultiJoinNode(outerJoin, noLookup(), DEFAULT_JOIN_LIMIT, functionResolution, determinismEvaluator); } @Test @@ -118,7 +116,7 @@ public void testDoesNotConvertNestedOuterJoins() .setSources(leftJoin, valuesC).setFilter(createEqualsExpression(a1, c1)) .setOutputVariables(a1, b1, c1) .build(); - assertEquals(toMultiJoinNode(joinNode, noLookup(), DEFAULT_JOIN_LIMIT, true, functionResolution, determinismEvaluator), expected); + assertEquals(toMultiJoinNode(joinNode, noLookup(), DEFAULT_JOIN_LIMIT, functionResolution, determinismEvaluator), expected); } @Test @@ -151,7 +149,7 @@ public void testRetainsOutputSymbols() .setFilter(and(createEqualsExpression(b1, c1), createEqualsExpression(a1, b1))) .setOutputVariables(a1, b1) .build(); - assertEquals(toMultiJoinNode(joinNode, noLookup(), DEFAULT_JOIN_LIMIT, true, functionResolution, determinismEvaluator), expected); + assertEquals(toMultiJoinNode(joinNode, noLookup(), DEFAULT_JOIN_LIMIT, functionResolution, determinismEvaluator), expected); } @Test @@ -208,9 +206,8 @@ public void testCombinesCriteriaAndFilters() MultiJoinNode expected = new MultiJoinNode( new LinkedHashSet<>(ImmutableList.of(valuesA, valuesB, valuesC)), and(createEqualsExpression(b1, c1), createEqualsExpression(a1, b1), bcFilter, abcFilter), - ImmutableList.of(a1, b1, b2, c1, c2), - Assignments.builder().build()); - assertEquals(toMultiJoinNode(joinNode, noLookup(), DEFAULT_JOIN_LIMIT, true, functionResolution, determinismEvaluator), expected); + ImmutableList.of(a1, b1, b2, c1, c2)); + assertEquals(toMultiJoinNode(joinNode, noLookup(), DEFAULT_JOIN_LIMIT, functionResolution, determinismEvaluator), expected); } @Test @@ -261,7 +258,7 @@ public void testConvertsBushyTrees() .setFilter(and(createEqualsExpression(a1, b1), createEqualsExpression(a1, c1), createEqualsExpression(d1, e1), createEqualsExpression(d2, e2), createEqualsExpression(b1, e1))) .setOutputVariables(a1, b1, c1, d1, d2, e1, e2) .build(); - assertEquals(toMultiJoinNode(joinNode, noLookup(), 5, true, functionResolution, determinismEvaluator), expected); + assertEquals(toMultiJoinNode(joinNode, noLookup(), 5, functionResolution, determinismEvaluator), expected); } @Test @@ -314,56 +311,10 @@ public void testMoreThanJoinLimit() .setFilter(and(createEqualsExpression(a1, c1), createEqualsExpression(b1, e1))) .setOutputVariables(a1, b1, c1, d1, d2, e1, e2) .build(); - assertEquals(toMultiJoinNode(joinNode, noLookup(), 2, true, functionResolution, determinismEvaluator), expected); - } - - @Test - public void testProjectNodesBetweenJoinNodesAreFlattenedForComplexEquiJoins() - { - PlanBuilder p = planBuilder(); - VariableReferenceExpression a1 = p.variable("A1"); - VariableReferenceExpression b1 = p.variable("B1"); - VariableReferenceExpression c1 = p.variable("C1"); - VariableReferenceExpression sum = p.variable("SUM"); - - ValuesNode valuesA = p.values(a1); - ValuesNode valuesB = p.values(b1); - ValuesNode valuesC = p.values(c1); - Assignments sumAssignment = Assignments.builder().put(sum, createAddExpression(a1, b1)).build(); - - ProjectNode intermediateProject = p.project(sumAssignment, p.join( - INNER, - valuesA, - valuesB, - ImmutableList.of(equiJoinClause(a1, b1)), - ImmutableList.of(a1, b1), - Optional.empty())); - JoinNode joinNode = p.join( - INNER, - intermediateProject, - valuesC, - ImmutableList.of(equiJoinClause(sum, c1)), - ImmutableList.of(), - Optional.empty()); - - MultiJoinNode expected = MultiJoinNode.builder() - .setSources(valuesA, valuesB, valuesC) - .setFilter(and(createEqualsExpression(a1, b1), createEqualsExpression(createAddExpression(a1, b1), c1))) - .setOutputVariables() - .build(); - assertEquals(toMultiJoinNode(joinNode, noLookup(), 5, /*handleComplexEquiJoins*/ true, functionResolution, determinismEvaluator), expected); - - // Negative test - when handleComplexEquiJoins = false, we have a split join space; the ProjectNode is not flattened - expected = MultiJoinNode.builder() - .setSources(intermediateProject, valuesC) - .setFilter(createEqualsExpression(sum, c1)) - .setOutputVariables() - .build(); - - assertEquals(toMultiJoinNode(joinNode, noLookup(), 5, /*handleComplexEquiJoins*/ false, functionResolution, determinismEvaluator), expected); + assertEquals(toMultiJoinNode(joinNode, noLookup(), 2, functionResolution, determinismEvaluator), expected); } - private RowExpression createEqualsExpression(RowExpression left, RowExpression right) + private RowExpression createEqualsExpression(VariableReferenceExpression left, VariableReferenceExpression right) { return call( OperatorType.EQUAL.name(), @@ -372,15 +323,6 @@ private RowExpression createEqualsExpression(RowExpression left, RowExpression r ImmutableList.of(left, right)); } - private RowExpression createAddExpression(RowExpression left, RowExpression right) - { - return call( - OperatorType.ADD.name(), - functionResolution.arithmeticFunction(OperatorType.ADD, left.getType(), right.getType()), - BIGINT, - ImmutableList.of(left, right)); - } - private EquiJoinClause equiJoinClause(VariableReferenceExpression variable1, VariableReferenceExpression variable2) { return new EquiJoinClause(variable1, variable2); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestReorderJoins.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestReorderJoins.java index 155c95d33b8e..c29c36667239 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestReorderJoins.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestReorderJoins.java @@ -22,7 +22,6 @@ import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.analyzer.FeaturesConfig.JoinDistributionType; import com.facebook.presto.sql.analyzer.FeaturesConfig.JoinReorderingStrategy; -import com.facebook.presto.sql.planner.assertions.BasePlanTest; import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; import com.facebook.presto.sql.planner.iterative.rule.test.RuleAssert; import com.facebook.presto.sql.planner.iterative.rule.test.RuleTester; @@ -33,14 +32,12 @@ import com.google.common.collect.ImmutableMap; import org.testng.annotations.AfterClass; import org.testng.annotations.BeforeClass; -import org.testng.annotations.DataProvider; import org.testng.annotations.Test; import java.util.List; import java.util.Optional; import static com.facebook.airlift.testing.Closeables.closeAllRuntimeException; -import static com.facebook.presto.SystemSessionProperties.HANDLE_COMPLEX_EQUI_JOINS; import static com.facebook.presto.SystemSessionProperties.JOIN_DISTRIBUTION_TYPE; import static com.facebook.presto.SystemSessionProperties.JOIN_MAX_BROADCAST_TABLE_SIZE; import static com.facebook.presto.SystemSessionProperties.JOIN_REORDERING_STRATEGY; @@ -52,13 +49,8 @@ import static com.facebook.presto.metadata.FunctionAndTypeManager.qualifyObjectName; import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinDistributionType.AUTOMATIC; import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinDistributionType.BROADCAST; -import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyTree; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.equiJoinClause; -import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression; -import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.filter; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.join; -import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; -import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.tableScan; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; import static com.facebook.presto.sql.planner.plan.JoinNode.DistributionType.PARTITIONED; import static com.facebook.presto.sql.planner.plan.JoinNode.DistributionType.REPLICATED; @@ -67,7 +59,6 @@ import static com.facebook.presto.sql.relational.Expressions.variable; public class TestReorderJoins - extends BasePlanTest { private RuleTester tester; private FunctionResolution functionResolution; @@ -76,37 +67,6 @@ public class TestReorderJoins private static final ImmutableList> TWO_ROWS = ImmutableList.of(ImmutableList.of(), ImmutableList.of()); private static final QualifiedName RANDOM = QualifiedName.of("random"); - @DataProvider - public static Object[][] tableSpecificationPermutations() - { - return new Object[][] { - {"supplier s, partsupp ps, customer c, orders o"}, - {"supplier s, partsupp ps, orders o, customer c"}, - {"supplier s, customer c, partsupp ps, orders o"}, - {"supplier s, customer c, orders o, partsupp ps"}, - {"supplier s, orders o, partsupp ps, customer c"}, - {"supplier s, orders o, customer c, partsupp ps"}, - {"partsupp ps, supplier s, customer c, orders o"}, - {"partsupp ps, supplier s, orders o, customer c"}, - {"partsupp ps, customer c, supplier s, orders o"}, - {"partsupp ps, customer c, orders o, supplier s"}, - {"partsupp ps, orders o, supplier s, customer c"}, - {"partsupp ps, orders o, customer c, supplier s"}, - {"customer c, supplier s, partsupp ps, orders o"}, - {"customer c, supplier s, orders o, partsupp ps"}, - {"customer c, partsupp ps, supplier s, orders o"}, - {"customer c, partsupp ps, orders o, supplier s"}, - {"customer c, orders o, supplier s, partsupp ps"}, - {"customer c, orders o, partsupp ps, supplier s"}, - {"orders o, supplier s, partsupp ps, customer c"}, - {"orders o, supplier s, customer c, partsupp ps"}, - {"orders o, partsupp ps, supplier s, customer c"}, - {"orders o, partsupp ps, customer c, supplier s"}, - {"orders o, customer c, supplier s, partsupp ps"}, - {"orders o, customer c, partsupp ps, supplier s"} - }; - } - @BeforeClass public void setUp() { @@ -114,8 +74,7 @@ public void setUp() ImmutableList.of(), ImmutableMap.of( JOIN_DISTRIBUTION_TYPE, JoinDistributionType.AUTOMATIC.name(), - JOIN_REORDERING_STRATEGY, JoinReorderingStrategy.AUTOMATIC.name(), - HANDLE_COMPLEX_EQUI_JOINS, "true"), + JOIN_REORDERING_STRATEGY, JoinReorderingStrategy.AUTOMATIC.name()), Optional.of(4)); this.functionResolution = new FunctionResolution(tester.getMetadata().getFunctionAndTypeManager().getFunctionAndTypeResolver()); } @@ -591,119 +550,6 @@ public void testReorderAndReplicate() values(ImmutableMap.of("A1", 0)))); } - /** - * This test asserts that join re-ordering works as expected for complex equi join clauses ('s.acctbal = c.acctbal + o.totalprice') - * and works irrespective of the order in which tables are specified in the FROM clause - * - * @param tableSpecificationOrder The table specification order - */ - @Test(dataProvider = "tableSpecificationPermutations") - public void testComplexEquiJoinCriteria(String tableSpecificationOrder) - { - // For a full connected join graph, we don't see any CrossJoins - String query = "select 1 from " + tableSpecificationOrder + " where s.suppkey = ps.suppkey and c.custkey = o.custkey and s.acctbal = c.acctbal + o.totalprice"; - PlanMatchPattern expectedPlan = - anyTree( - join(INNER, - ImmutableList.of(equiJoinClause("PS_SUPPKEY", "S_SUPPKEY")), - anyTree(tableScan("partsupp", ImmutableMap.of("PS_SUPPKEY", "suppkey"))), - anyTree( - join(INNER, - ImmutableList.of(equiJoinClause("SUM", "S_ACCTBAL")), - anyTree( - project(ImmutableMap.of("SUM", expression("C_ACCTBAL + O_TOTALPRICE")), - join(INNER, - ImmutableList.of(equiJoinClause("O_CUSTKEY", "C_CUSTKEY")), - anyTree( - tableScan("orders", ImmutableMap.of("O_CUSTKEY", "custkey", "O_TOTALPRICE", "totalprice"))), - anyTree( - tableScan("customer", ImmutableMap.of("C_CUSTKEY", "custkey", "C_ACCTBAL", "acctbal")))))), - anyTree( - tableScan("supplier", ImmutableMap.of("S_ACCTBAL", "acctbal", "S_SUPPKEY", "suppkey"))))))); - assertPlan(query, expectedPlan); - - // The plan is identical to the plan for the fully spelled out version of the Join - String fullQuery = "select 1 from (supplier s inner join partsupp ps on s.suppkey = ps.suppkey) inner join (orders o inner join customer c on c.custkey = o.custkey) " + - " on s.acctbal = c.acctbal + o.totalprice"; - assertPlan(fullQuery, expectedPlan); - } - - @Test - public void testComplexEquiJoinCriteriaForDisjointGraphs() - { - // If the join clause is written with the Left/Right side referring to both sides of a Join node, an equi-join condition cannot be inferred - // and the join space is broken up. Hence, we observe a CrossJoin node - assertPlan("select 1 from supplier s, partsupp ps, customer c, orders o where s.suppkey = ps.suppkey and c.custkey = o.custkey and s.acctbal - c.acctbal = o.totalprice", - anyTree( - join(INNER, - ImmutableList.of(equiJoinClause("C_CUSTKEY", "O_CUSTKEY"), equiJoinClause("SUBTRACT", "O_TOTALPRICE")), - anyTree( - project(ImmutableMap.of("SUBTRACT", expression("S_ACCTBAL - C_ACCTBAL")), - join(INNER, - ImmutableList.of(), //CrossJoin - join(INNER, - ImmutableList.of(equiJoinClause("PS_SUPPKEY", "S_SUPPKEY")), - anyTree(tableScan("partsupp", ImmutableMap.of("PS_SUPPKEY", "suppkey"))), - anyTree( - tableScan("supplier", ImmutableMap.of("S_ACCTBAL", "acctbal", "S_SUPPKEY", "suppkey")))), - anyTree( - tableScan("customer", ImmutableMap.of("C_CUSTKEY", "custkey", "C_ACCTBAL", "acctbal")))))), - anyTree( - tableScan("orders", ImmutableMap.of("O_CUSTKEY", "custkey", "O_TOTALPRICE", "totalprice")))))); - - // The table specification order determines the join order for such cases - // With the below table specification order, the planner adds the complex equi-join condition as a FilterNode on top of a JoinNode - assertPlan("select 1 from orders o, customer c, supplier s, partsupp ps where s.suppkey = ps.suppkey and c.custkey = o.custkey and s.acctbal - c.acctbal = o.totalprice", - anyTree( - join(INNER, - ImmutableList.of(equiJoinClause("PS_SUPPKEY", "S_SUPPKEY")), - anyTree( - tableScan("partsupp", ImmutableMap.of("PS_SUPPKEY", "suppkey"))), - anyTree( - filter("O_TOTALPRICE = S_ACCTBAL - C_ACCTBAL", - join(INNER, - ImmutableList.of(), //CrossJoin - join(INNER, - ImmutableList.of(equiJoinClause("O_CUSTKEY", "C_CUSTKEY")), - anyTree(tableScan("orders", ImmutableMap.of("O_CUSTKEY", "custkey", "O_TOTALPRICE", "totalprice"))), - anyTree( - tableScan("customer", ImmutableMap.of("C_CUSTKEY", "custkey", "C_ACCTBAL", "acctbal")))), - anyTree( - tableScan("supplier", ImmutableMap.of("S_ACCTBAL", "acctbal", "S_SUPPKEY", "suppkey"))))))))); - - // For sub-graphs that are fully connected, join-reordering works with complex predicates as expected - // The rest of the join graph is connected using a CrossJoin - assertPlan("select 1 " + - "from orders o, customer c, supplier s, partsupp ps, part p " + - "where s.suppkey = ps.suppkey " + - " and c.custkey = o.custkey " + - " and s.acctbal = c.acctbal + o.totalprice" + - " and ps.partkey - p.partkey = 0 ", - anyTree( - filter("PS_PARTKEY - P_PARTKEY = 0", - join(INNER, - ImmutableList.of(), // CrossJoin - join(INNER, - ImmutableList.of(equiJoinClause("PS_SUPPKEY", "S_SUPPKEY")), - anyTree( - tableScan("partsupp", ImmutableMap.of("PS_SUPPKEY", "suppkey", "PS_PARTKEY", "partkey"))), - anyTree( - join(INNER, - ImmutableList.of(equiJoinClause("SUM", "S_ACCTBAL")), - anyTree( - project(ImmutableMap.of("SUM", expression("C_ACCTBAL + O_TOTALPRICE")), - join(INNER, - ImmutableList.of(equiJoinClause("O_CUSTKEY", "C_CUSTKEY")), - anyTree( - tableScan("orders", ImmutableMap.of("O_CUSTKEY", "custkey", "O_TOTALPRICE", "totalprice"))), - anyTree( - tableScan("customer", ImmutableMap.of("C_CUSTKEY", "custkey", "C_ACCTBAL", "acctbal")))))), - anyTree( - tableScan("supplier", ImmutableMap.of("S_ACCTBAL", "acctbal", "S_SUPPKEY", "suppkey")))))), - anyTree( - tableScan("part", ImmutableMap.of("P_PARTKEY", "partkey"))))))); - } - private RuleAssert assertReorderJoins() { return tester.assertThat(new ReorderJoins(new CostComparator(1, 1, 1), tester.getMetadata())); diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestJoinQueries.java b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestJoinQueries.java index 374310940b85..369b0277abe6 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestJoinQueries.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestJoinQueries.java @@ -24,7 +24,6 @@ import com.google.common.collect.Iterables; import org.testng.annotations.Test; -import static com.facebook.presto.SystemSessionProperties.HANDLE_COMPLEX_EQUI_JOINS; import static com.facebook.presto.SystemSessionProperties.JOINS_NOT_NULL_INFERENCE_STRATEGY; import static com.facebook.presto.SystemSessionProperties.JOIN_DISTRIBUTION_TYPE; import static com.facebook.presto.SystemSessionProperties.JOIN_REORDERING_STRATEGY; @@ -471,29 +470,6 @@ public void testJoinWithComplexExpressions3() "SELECT SUM(custkey) FROM lineitem JOIN orders ON lineitem.orderkey + 1 = orders.orderkey + 1", // H2 takes a million years because it can't join efficiently on a non-indexed field/expression "SELECT SUM(custkey) FROM lineitem JOIN orders ON lineitem.orderkey = orders.orderkey "); - - Session handleComplexEquiJoins = Session.builder(getSession()) - .setSystemProperty(HANDLE_COMPLEX_EQUI_JOINS, "true") - .build(); - - assertQueryWithSameQueryRunner( - handleComplexEquiJoins, - "select c.custkey, ps.partkey, s.suppkey, o.orderkey " + - "from customer c, " + - " partsupp ps, " + - " orders o, " + - " supplier s " + - "where s.suppkey = ps.suppkey " + - " and c.custkey = o.custkey " + - " and s.nationkey + ps.partkey = c.nationkey " + - "order by c.custkey, ps.partkey, s.suppkey, o.orderkey", - noJoinReordering(), - "select c.custkey, ps.partkey, s.suppkey, o.orderkey " + - "from (customer c inner join orders o ON c.custkey = o.custkey) " + - " inner join " + - " (partsupp ps inner join supplier s ON s.suppkey = ps.suppkey) " + - " on s.nationkey + ps.partkey = c.nationkey " + - "order by c.custkey, ps.partkey, s.suppkey, o.orderkey"); } @Test diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueryFramework.java b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueryFramework.java index 8baa5865443a..8a9eeb580d78 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueryFramework.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueryFramework.java @@ -169,12 +169,6 @@ protected void assertQueryWithSameQueryRunner(Session session, @Language("SQL") QueryAssertions.assertQuery(queryRunner, session, actual, queryRunner, expected, false, false); } - protected void assertQueryWithSameQueryRunner(Session actualSession, @Language("SQL") String actual, Session expectedSession, @Language("SQL") String expected) - { - checkArgument(!actual.equals(expected)); - QueryAssertions.assertQuery(queryRunner, actualSession, actual, queryRunner, expectedSession, expected, false, false); - } - protected void assertQuery(Session session, @Language("SQL") String actual, @Language("SQL") String expected) { QueryAssertions.assertQuery(queryRunner, session, actual, expectedQueryRunner, expected, false, false);