Skip to content

Commit

Permalink
add the source analyze for join criteria
Browse files Browse the repository at this point in the history
  • Loading branch information
goldmedal committed May 23, 2024
1 parent c6087ef commit a1777e9
Show file tree
Hide file tree
Showing 5 changed files with 186 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ protected Void visitQuerySpecification(QuerySpecification node, Context ignored)
QueryAnalysis.Builder builder = QueryAnalysis.builder();
Context context = new Context(builder, analysis.getScope(node));
process(node.getSelect(), context);
node.getFrom().ifPresent(from -> builder.setRelation(RelationAnalyzer.analyze(from, sessionContext, mdl)));
node.getFrom().ifPresent(from -> builder.setRelation(RelationAnalyzer.analyze(from, sessionContext, mdl, analysis)));
node.getWhere().ifPresent(where -> builder.setFilter(FilterAnalyzer.analyze(where)));

if (node.getGroupBy().isPresent()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,17 @@
package io.wren.base.sqlrewrite.analyzer.decisionpoint;

import java.util.List;
import java.util.Objects;

import static io.wren.base.Utils.checkArgument;
import static java.util.Objects.requireNonNull;

public abstract class RelationAnalysis
{
static JoinRelation join(Type type, String alias, RelationAnalysis left, RelationAnalysis right, String criteria)
static JoinRelation join(Type type, String alias, RelationAnalysis left, RelationAnalysis right, String criteria, List<ExprSource> exprSources)
{
checkArgument(type != Type.TABLE && type != Type.SUBQUERY, "type should be a join type");
return new JoinRelation(type, alias, left, right, criteria);
return new JoinRelation(type, alias, left, right, criteria, exprSources);
}

static TableRelation table(String tableName, String alias)
Expand Down Expand Up @@ -74,13 +75,15 @@ public static class JoinRelation
private final RelationAnalysis left;
private final RelationAnalysis right;
private final String criteria;
private final List<ExprSource> exprSources;

public JoinRelation(Type type, String alias, RelationAnalysis left, RelationAnalysis right, String criteria)
public JoinRelation(Type type, String alias, RelationAnalysis left, RelationAnalysis right, String criteria, List<ExprSource> exprSources)
{
super(type, alias);
this.left = requireNonNull(left, "left is null");
this.right = requireNonNull(right, "right is null");
this.criteria = criteria;
this.exprSources = exprSources == null ? List.of() : exprSources;
}

public RelationAnalysis getLeft()
Expand All @@ -97,6 +100,11 @@ public String getCriteria()
{
return criteria;
}

public List<ExprSource> getExprSources()
{
return exprSources;
}
}

public static class TableRelation
Expand Down Expand Up @@ -132,4 +140,28 @@ public List<QueryAnalysis> getBody()
return body;
}
}

public record ExprSource(String expression, String sourceDataset)
{
@Override
public boolean equals(Object o)
{
if (this == o) {
return true;
}

if (o == null || getClass() != o.getClass()) {
return false;
}

ExprSource that = (ExprSource) o;
return Objects.equals(expression, that.expression) && Objects.equals(sourceDataset, that.sourceDataset);
}

@Override
public int hashCode()
{
return Objects.hash(expression, sourceDataset);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,13 @@

package io.wren.base.sqlrewrite.analyzer.decisionpoint;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import io.trino.sql.tree.AliasedRelation;
import io.trino.sql.tree.AstVisitor;
import io.trino.sql.tree.DefaultExpressionTraversalVisitor;
import io.trino.sql.tree.DereferenceExpression;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.FunctionRelation;
import io.trino.sql.tree.Identifier;
import io.trino.sql.tree.Join;
Expand All @@ -25,6 +30,7 @@
import io.trino.sql.tree.Lateral;
import io.trino.sql.tree.NaturalJoin;
import io.trino.sql.tree.PatternRecognitionRelation;
import io.trino.sql.tree.QualifiedName;
import io.trino.sql.tree.QuerySpecification;
import io.trino.sql.tree.Relation;
import io.trino.sql.tree.SampledRelation;
Expand All @@ -35,31 +41,39 @@
import io.trino.sql.tree.Values;
import io.wren.base.SessionContext;
import io.wren.base.WrenMDL;
import io.wren.base.sqlrewrite.analyzer.Analysis;
import io.wren.base.sqlrewrite.analyzer.Scope;

import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.Set;

import static io.trino.sql.tree.DereferenceExpression.getQualifiedName;
import static java.lang.String.format;
import static java.util.stream.Collectors.joining;

public class RelationAnalyzer
{
private RelationAnalyzer() {}

public static RelationAnalysis analyze(Relation relation, SessionContext sessionContext, WrenMDL wrenMDL)
public static RelationAnalysis analyze(Relation relation, SessionContext sessionContext, WrenMDL wrenMDL, Analysis analysis)
{
return new Visitor(sessionContext, wrenMDL).process(relation, null);
return new Visitor(sessionContext, wrenMDL, analysis).process(relation, null);
}

static class Visitor
extends AstVisitor<RelationAnalysis, Void>
{
private final SessionContext sessionContext;
private final WrenMDL wrenMDL;
private final Analysis analysis;

public Visitor(SessionContext sessionContext, WrenMDL wrenMDL)
public Visitor(SessionContext sessionContext, WrenMDL wrenMDL, Analysis analysis)
{
this.sessionContext = sessionContext;
this.wrenMDL = wrenMDL;
this.analysis = analysis;
}

@Override
Expand Down Expand Up @@ -109,9 +123,14 @@ protected RelationAnalysis visitJoin(Join node, Void context)
{
RelationAnalysis left = process(node.getLeft(), context);
RelationAnalysis right = process(node.getRight(), context);

Scope scope = analysis.getScope(node);
List<RelationAnalysis.ExprSource> exprSources = node.getCriteria().map(criteria -> analyzeCriteria(criteria, scope))
.orElse(null);
return new RelationAnalysis.JoinRelation(
RelationAnalysis.Type.valueOf(format("%s_JOIN", node.getType())),
null, left, right, node.getCriteria().map(this::formatCriteria).orElse(null));
null, left, right, node.getCriteria().map(this::formatCriteria).orElse(null),
exprSources);
}

private String formatCriteria(JoinCriteria criteria)
Expand All @@ -135,6 +154,24 @@ private String formatCriteria(JoinCriteria criteria)
return builder.toString();
}

private List<RelationAnalysis.ExprSource> analyzeCriteria(JoinCriteria criteria, Scope scope)
{
Set<RelationAnalysis.ExprSource> exprSources = new HashSet<>();
switch (criteria) {
case JoinOn joinOn:
exprSources.addAll(ExpressionSourceAnalyzer.analyze(joinOn.getExpression(), scope));
break;
case JoinUsing joinUsing:
joinUsing.getColumns().forEach(column -> exprSources.addAll(ExpressionSourceAnalyzer.analyze(column, scope)));
break;
case NaturalJoin ignored:
break;
default:
throw new IllegalArgumentException("Unsupported join criteria: " + criteria);
}
return ImmutableList.copyOf(exprSources);
}

@Override
protected RelationAnalysis visitAliasedRelation(AliasedRelation node, Void context)
{
Expand All @@ -143,7 +180,7 @@ protected RelationAnalysis visitAliasedRelation(AliasedRelation node, Void conte
return switch (relationAnalysis) {
case RelationAnalysis.TableRelation tableRelation -> RelationAnalysis.table(tableRelation.getTableName(), node.getAlias().getValue());
case RelationAnalysis.JoinRelation joinRelation ->
RelationAnalysis.join(joinRelation.getType(), node.getAlias().getValue(), joinRelation.getLeft(), joinRelation.getRight(), joinRelation.getCriteria());
RelationAnalysis.join(joinRelation.getType(), node.getAlias().getValue(), joinRelation.getLeft(), joinRelation.getRight(), joinRelation.getCriteria(), joinRelation.getExprSources());
case RelationAnalysis.SubqueryRelation subqueryRelation -> RelationAnalysis.subquery(node.getAlias().getValue(), subqueryRelation.getBody());
default -> throw new IllegalStateException("Unexpected value: " + relationAnalysis);
};
Expand Down Expand Up @@ -177,4 +214,40 @@ protected RelationAnalysis visitLateral(Lateral node, Void context)
throw new UnsupportedOperationException("Analyze Lateral is not supported yet");
}
}

static class ExpressionSourceAnalyzer
extends DefaultExpressionTraversalVisitor<Void>
{
static Set<RelationAnalysis.ExprSource> analyze(Expression expression, Scope scope)
{
ExpressionSourceAnalyzer analyzer = new ExpressionSourceAnalyzer(scope);
analyzer.process(expression, null);
return ImmutableSet.copyOf(analyzer.exprSources);
}

private final Scope scope;
private final Set<RelationAnalysis.ExprSource> exprSources = new HashSet<>();

public ExpressionSourceAnalyzer(Scope scope)
{
this.scope = scope;
}

@Override
protected Void visitIdentifier(Identifier node, Void context)
{
scope.getRelationType().resolveFields(QualifiedName.of(node.getValue()))
.forEach(field -> exprSources.add(new RelationAnalysis.ExprSource(node.getValue(), field.getTableName().getSchemaTableName().getTableName())));
return null;
}

@Override
protected Void visitDereferenceExpression(DereferenceExpression node, Void context)
{
Optional.ofNullable(getQualifiedName(node)).ifPresent(qualifiedName ->
scope.getRelationType().resolveFields(qualifiedName)
.forEach(field -> exprSources.add(new RelationAnalysis.ExprSource(qualifiedName.toString(), field.getTableName().getSchemaTableName().getTableName()))));
return null;
}
}
}
13 changes: 10 additions & 3 deletions wren-main/src/main/java/io/wren/main/web/AnalysisResource.java
Original file line number Diff line number Diff line change
Expand Up @@ -113,23 +113,25 @@ private static RelationAnalysisDto toRelationAnalysisDto(RelationAnalysis relati
{
return switch (relationAnalysis) {
case RelationAnalysis.TableRelation tableRelation ->
new RelationAnalysisDto(tableRelation.getType().name(), tableRelation.getAlias(), null, null, null, tableRelation.getTableName(), null);
new RelationAnalysisDto(tableRelation.getType().name(), tableRelation.getAlias(), null, null, null, tableRelation.getTableName(), null, null);
case RelationAnalysis.JoinRelation joinRelation -> new RelationAnalysisDto(
joinRelation.getType().name(),
joinRelation.getAlias(),
toRelationAnalysisDto(joinRelation.getLeft()),
toRelationAnalysisDto(joinRelation.getRight()),
joinRelation.getCriteria(),
null,
null);
null,
joinRelation.getExprSources().stream().map(AnalysisResource::toExprSourceDto).toList());
case RelationAnalysis.SubqueryRelation subqueryRelation -> new RelationAnalysisDto(
subqueryRelation.getType().name(),
subqueryRelation.getAlias(),
null,
null,
null,
null,
subqueryRelation.getBody().stream().map(AnalysisResource::toQueryAnalysisDto).toList());
subqueryRelation.getBody().stream().map(AnalysisResource::toQueryAnalysisDto).toList(),
null);
case null -> null;
default -> throw new IllegalArgumentException("Unsupported relation analysis: " + relationAnalysis);
};
Expand All @@ -139,4 +141,9 @@ private static SortItemAnalysisDto toSortItemAnalysisDto(QueryAnalysis.SortItemA
{
return new SortItemAnalysisDto(sortItemAnalysis.getExpression(), sortItemAnalysis.getOrdering().name());
}

private static QueryAnalysisDto.ExprSourceDto toExprSourceDto(RelationAnalysis.ExprSource exprSource)
{
return new QueryAnalysisDto.ExprSourceDto(exprSource.expression(), exprSource.sourceDataset());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;

@JsonInclude(JsonInclude.Include.NON_NULL)
Expand Down Expand Up @@ -115,9 +116,18 @@ public static class RelationAnalysisDto
private String criteria;
private String tableName;
private List<QueryAnalysisDto> body;
private List<ExprSourceDto> exprSources;

@JsonCreator
public RelationAnalysisDto(String type, String alias, RelationAnalysisDto left, RelationAnalysisDto right, String criteria, String tableName, List<QueryAnalysisDto> body)
public RelationAnalysisDto(
String type,
String alias,
RelationAnalysisDto left,
RelationAnalysisDto right,
String criteria,
String tableName,
List<QueryAnalysisDto> body,
List<ExprSourceDto> exprSources)
{
this.type = type;
this.alias = alias;
Expand All @@ -126,6 +136,7 @@ public RelationAnalysisDto(String type, String alias, RelationAnalysisDto left,
this.criteria = criteria;
this.tableName = tableName;
this.body = body;
this.exprSources = exprSources;
}

@JsonProperty
Expand Down Expand Up @@ -169,6 +180,12 @@ public List<QueryAnalysisDto> getBody()
{
return body;
}

@JsonProperty
public List<ExprSourceDto> getExprSources()
{
return exprSources;
}
}

@JsonInclude(JsonInclude.Include.NON_NULL)
Expand Down Expand Up @@ -238,4 +255,48 @@ public String getOrdering()
return ordering;
}
}

public static class ExprSourceDto
{
private String expression;
private String sourceDataset;

@JsonCreator
public ExprSourceDto(String expression, String sourceDataset)
{
this.expression = expression;
this.sourceDataset = sourceDataset;
}

@JsonProperty
public String getExpression()
{
return expression;
}

@JsonProperty
public String getSourceDataset()
{
return sourceDataset;
}

@Override
public boolean equals(Object o)
{
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
ExprSourceDto that = (ExprSourceDto) o;
return Objects.equals(expression, that.expression) && Objects.equals(sourceDataset, that.sourceDataset);
}

@Override
public int hashCode()
{
return Objects.hash(expression, sourceDataset);
}
}
}

0 comments on commit a1777e9

Please sign in to comment.