Skip to content

Commit

Permalink
Support set operations and subquery expression
Browse files Browse the repository at this point in the history
  • Loading branch information
goldmedal committed Mar 26, 2024
1 parent 1f53aed commit 8bed5a4
Show file tree
Hide file tree
Showing 5 changed files with 139 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ public class Analysis
private final Set<CumulativeMetric> cumulativeMetrics = new HashSet<>();
private final Set<View> views = new HashSet<>();
private final Multimap<CatalogSchemaTableName, String> collectedColumns = HashMultimap.create();
private final Map<NodeRef<Expression>, Field> referenceFields = new HashMap<>();
private final List<SimplePredicate> simplePredicates = new ArrayList<>();

private final Set<Node> requiredSourceNodes = new HashSet<>();
Expand Down Expand Up @@ -174,6 +175,16 @@ public Multimap<CatalogSchemaTableName, String> getCollectedColumns()
return collectedColumns;
}

public void addReferenceFields(Map<NodeRef<Expression>, Field> referenceFields)
{
this.referenceFields.putAll(referenceFields);
}

public Map<NodeRef<Expression>, Field> getReferenceFields()
{
return referenceFields;
}

void addTypeCoercion(NodeRef<Node> nodeRef, Node node)
{
typeCoercionMap.put(nodeRef, node);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
import io.trino.sql.tree.Identifier;
import io.trino.sql.tree.NodeRef;
import io.trino.sql.tree.QualifiedName;
import io.trino.sql.tree.SubqueryExpression;
import io.wren.base.SessionContext;
import io.wren.base.WrenMDL;

import java.util.ArrayList;
import java.util.HashMap;
Expand All @@ -36,9 +39,9 @@ public class ExpressionAnalyzer
{
private ExpressionAnalyzer() {}

public static ExpressionAnalysis analyze(Scope scope, Expression expression)
public static ExpressionAnalysis analyze(Scope scope, Expression expression, SessionContext sessionContext, WrenMDL wrenMDL, Analysis analysis)
{
ExpressionVisitor visitor = new ExpressionVisitor(scope);
ExpressionVisitor visitor = new ExpressionVisitor(scope, sessionContext, wrenMDL, analysis);
visitor.process(expression);

return new ExpressionAnalysis(visitor.getReferenceFields(), visitor.getPredicates(), visitor.isRequireRelation());
Expand All @@ -48,13 +51,23 @@ private static class ExpressionVisitor
extends DefaultTraversalVisitor<Void>
{
private final Scope scope;
private final WrenMDL wrenMDL;
private final SessionContext sessionContext;
private final Analysis analysis;
private final Map<NodeRef<Expression>, Field> referenceFields = new HashMap<>();
private final List<ComparisonExpression> predicates = new ArrayList<>();
private boolean requireRelation;

public ExpressionVisitor(Scope scope)
public ExpressionVisitor(
Scope scope,
SessionContext sessionContext,
WrenMDL wrenMDL,
Analysis analysis)
{
this.scope = requireNonNull(scope);
this.scope = requireNonNull(scope, "scope is null");
this.sessionContext = requireNonNull(sessionContext, "sessionContext is null");
this.wrenMDL = requireNonNull(wrenMDL, "wrenMDL is null");
this.analysis = requireNonNull(analysis, "analysis is null");
}

@Override
Expand Down Expand Up @@ -101,6 +114,13 @@ protected Void visitFunctionCall(FunctionCall node, Void context)
return null;
}

@Override
protected Void visitSubqueryExpression(SubqueryExpression node, Void context)
{
StatementAnalyzer.analyze(analysis, node.getQuery(), sessionContext, wrenMDL);
return null;
}

public Map<NodeRef<Expression>, Field> getReferenceFields()
{
return referenceFields;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.wren.base.sqlrewrite.analyzer;

import io.trino.sql.tree.Node;
import io.wren.base.ErrorCodeSupplier;
import io.wren.base.Location;
import io.wren.base.WrenException;

import java.util.Optional;

import static java.lang.String.format;

public final class SemanticExceptions
{
private SemanticExceptions() {}

public static WrenException semanticException(ErrorCodeSupplier code, Node node, String format, Object... args)
{
return semanticException(code, node, null, format, args);
}

public static WrenException semanticException(ErrorCodeSupplier code, Node node, Throwable cause, String format, Object... args)
{
throw new WrenException(code, extractLocation(node), format(format, args), cause);
}

public static Optional<Location> extractLocation(Node node)
{
return node.getLocation()
.map(location -> new Location(location.getLineNumber(), location.getColumnNumber()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,11 @@
import io.trino.sql.tree.Query;
import io.trino.sql.tree.QuerySpecification;
import io.trino.sql.tree.SelectItem;
import io.trino.sql.tree.SetOperation;
import io.trino.sql.tree.SingleColumn;
import io.trino.sql.tree.Statement;
import io.trino.sql.tree.Table;
import io.trino.sql.tree.TableSubquery;
import io.trino.sql.tree.Union;
import io.trino.sql.tree.Unnest;
import io.trino.sql.tree.Values;
import io.trino.sql.tree.With;
Expand All @@ -61,11 +61,14 @@
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;

import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.trino.sql.QueryUtil.getQualifiedName;
import static io.wren.base.metadata.StandardErrorCode.TYPE_MISMATCH;
import static io.wren.base.sqlrewrite.Utils.toCatalogSchemaTableName;
import static io.wren.base.sqlrewrite.analyzer.Analysis.SimplePredicate;
import static java.lang.String.format;
import static java.util.Locale.ENGLISH;
import static java.util.Objects.requireNonNull;
import static java.util.stream.Collectors.toUnmodifiableSet;

Expand Down Expand Up @@ -314,7 +317,7 @@ protected Scope visitQuerySpecification(QuerySpecification node, Optional<Scope>
Scope sourceScope = analyzeFrom(node, scope);
List<Expression> outputExpressions = analyzeSelect(node, sourceScope);
node.getWhere().ifPresent(where -> analyzeWhere(where, sourceScope));
node.getHaving().ifPresent(having -> ExpressionAnalyzer.analyze(sourceScope, having));
node.getHaving().ifPresent(having -> analyzeExpression(having, sourceScope));
node.getLimit().ifPresent(limit -> analysis.setLimit(((Limit) limit).getRowCount()));
node.getOrderBy().ifPresent(orderBy -> orderBy.getSortItems()
.forEach(item -> {
Expand All @@ -328,6 +331,7 @@ protected Scope visitQuerySpecification(QuerySpecification node, Optional<Scope>
}
analysis.addSortItem(new Analysis.SortItemAnalysis(name, item.getOrdering().name()));
}));
// TODO: this scope is wrong.
return createAndAssignScope(node, scope, sourceScope);
}

Expand Down Expand Up @@ -366,8 +370,7 @@ private void analyzeSelectSingleColumn(SingleColumn singleColumn, Scope scope, I
{
outputExpressions.add(singleColumn.getAlias().map(name -> (Expression) name).orElse(singleColumn.getExpression()));
// TODO: handle when singleColumn is a subquery
ExpressionAnalysis expressionAnalysis = ExpressionAnalyzer.analyze(scope, singleColumn.getExpression());
analysis.addCollectedColumns(expressionAnalysis.getCollectedFields());
ExpressionAnalysis expressionAnalysis = analyzeExpression(singleColumn.getExpression(), scope);

if (expressionAnalysis.isRequireRelation()) {
analysis.addRequiredSourceNode(scope.getRelationId().getSourceNode()
Expand All @@ -390,8 +393,7 @@ private Scope analyzeFrom(QuerySpecification node, Optional<Scope> scope)

private void analyzeWhere(Expression node, Scope scope)
{
ExpressionAnalysis expressionAnalysis = ExpressionAnalyzer.analyze(scope, node);
analysis.addCollectedColumns(expressionAnalysis.getCollectedFields());
ExpressionAnalysis expressionAnalysis = analyzeExpression(node, scope);
Map<NodeRef<Expression>, Field> fields = expressionAnalysis.getReferencedFields();
expressionAnalysis.getPredicates().stream()
.filter(PredicateMatcher.PREDICATE_MATCHER::shapeMatches)
Expand Down Expand Up @@ -452,10 +454,30 @@ protected Scope visitFunctionRelation(FunctionRelation node, Optional<Scope> sco
}

@Override
protected Scope visitUnion(Union node, Optional<Scope> scope)
protected Scope visitSetOperation(SetOperation node, Optional<Scope> scope)
{
// TODO: output scope here isn't right
return Scope.builder().parent(scope).build();
checkState(node.getRelations().size() >= 2);
List<RelationType> relationTypes = node.getRelations().stream()
.map(relation -> process(relation, scope).getRelationType()).collect(toImmutableList());
String setOperationName = node.getClass().getSimpleName().toUpperCase(ENGLISH);
List<Field> outputFields = relationTypes.get(0).getFields();
for (RelationType relationType : relationTypes) {
int outputFieldSize = outputFields.size();
int descFieldSize = relationType.getFields().size();
if (outputFieldSize != descFieldSize) {
throw SemanticExceptions.semanticException(
TYPE_MISMATCH,
node,
"%s query has different number of fields: %d, %d",
setOperationName,
outputFieldSize,
descFieldSize);
}

// TODO: check type compatibility
}

return createAndAssignScope(node, scope, new RelationType(outputFields));
}

@Override
Expand All @@ -467,10 +489,10 @@ protected Scope visitJoin(Join node, Optional<Scope> scope)
Scope outputScope = createAndAssignScope(node, scope, relationType);

JoinCriteria criteria = node.getCriteria().orElse(null);
// TODO: handle other join types
// TODO: handle other join type
if (criteria instanceof JoinOn) {
Expression expression = ((JoinOn) criteria).getExpression();
analysis.addCollectedColumns(ExpressionAnalyzer.analyze(outputScope, expression).getCollectedFields());
analyzeExpression(expression, outputScope);
typeCoercionOptional.ifPresent(typeCoercion -> {
Optional<Expression> coerced = typeCoercion.coerceExpression(expression, outputScope);
if (coerced.isPresent()) {
Expand Down Expand Up @@ -565,5 +587,13 @@ private Scope createAndAssignScope(Node node, Optional<Scope> parentScope, Scope
analysis.setScope(node, newScope);
return newScope;
}

private ExpressionAnalysis analyzeExpression(Expression expression, Scope scope)
{
ExpressionAnalysis expressionAnalysis = ExpressionAnalyzer.analyze(scope, expression, sessionContext, wrenMDL, analysis);
analysis.addCollectedColumns(expressionAnalysis.getCollectedFields());
analysis.addReferenceFields(expressionAnalysis.getReferencedFields());
return expressionAnalysis;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import static io.trino.sql.parser.ParsingOptions.DecimalLiteralTreatment.AS_DECIMAL;
import static io.wren.base.sqlrewrite.Utils.SQL_PARSER;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

public class TestAllRulesRewrite
extends AbstractTestFramework
Expand Down Expand Up @@ -116,6 +117,16 @@ public Object[][] wrenUsedCases()
"values('Gusare', 2560), ('HisoHiso Banashi', 1500), ('Dakara boku wa ongaku o yameta', 2553)"},
{"select band, cast(price as integer) from useMetric order by band", "values ('Yorushika', 2553), ('ZUTOMAYO', 4060)"},
{"select * from \"Order\"", "values (1, 1), (2, 1), (3, 2), (4, 3)"},
{"select name, price from Album where id in (select albumId from \"Order\")",
"values('Gusare', 2560), ('HisoHiso Banashi', 1500), ('Dakara boku wa ongaku o yameta', 2553)"},
{"select name, price from Album where id not in (select albumId from \"Order\")",
"values(1, 1) limit 0"},
{"select * from (select name ,price from Album where bandId = 1 union select name, price from Album where bandId = 2) order by price",
"values('HisoHiso Banashi', 1500), ('Dakara boku wa ongaku o yameta', 2553), ('Gusare', 2560)"},
{"select * from (select name ,price from Album where bandId = 1 except select name, price from Album where bandId = 2) order by price",
"values('HisoHiso Banashi', 1500), ('Gusare', 2560)"},
{"select * from (select name ,price from Album where bandId = 1 intersect select name, price from Album where bandId = 2) order by price",
"values(1, 1) limit 0"}
};
}

Expand Down Expand Up @@ -149,6 +160,14 @@ public void testWrenNoRewrite(String original)
assertThat(rewrite(original)).isEqualTo(formatSql(expectedState));
}

// TODO: The scope of QuerySpecification is wrong. Enable it after fixing the scope.
@Test(enabled = false)
public void testSetOperationColumnNoMatch()
{
assertThatThrownBy(() -> rewrite("select name, price from Album union select price from Album"))
.hasMessageFindingMatch("query has different number of fields: expected 2, found 1");
}

private String rewrite(String sql)
{
return WrenPlanner.rewrite(sql, DEFAULT_SESSION_CONTEXT, new AnalyzedMDL(wrenMDL, null));
Expand Down

0 comments on commit 8bed5a4

Please sign in to comment.