Skip to content

Commit

Permalink
fix(core): fix the StatementAnalyzer for UNNEST and LATERAL query (#972)
Browse files Browse the repository at this point in the history
  • Loading branch information
goldmedal authored Dec 9, 2024
1 parent 9399b5d commit b221920
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ protected Void visitDereferenceExpression(DereferenceExpression node, Void conte
{
QualifiedName qualifiedName = getQualifiedName(node);
if (qualifiedName != null) {
scope.getRelationType().resolveAnyField(qualifiedName)
scope.resolveAnyField(qualifiedName)
.ifPresent(field -> referenceFields.put(NodeRef.of(node), field));
}
else {
Expand All @@ -95,7 +95,7 @@ protected Void visitDereferenceExpression(DereferenceExpression node, Void conte
protected Void visitIdentifier(Identifier node, Void context)
{
QualifiedName qualifiedName = QualifiedName.of(ImmutableList.of(node));
scope.getRelationType().resolveAnyField(qualifiedName)
scope.resolveAnyField(qualifiedName)
.ifPresent(field -> referenceFields.put(NodeRef.of(node), field));
return null;
}
Expand All @@ -104,7 +104,7 @@ protected Void visitIdentifier(Identifier node, Void context)
protected Void visitSubscriptExpression(SubscriptExpression node, Void context)
{
QualifiedName qualifiedName = getQualifiedName(node.getBase());
scope.getRelationType().resolveAnyField(qualifiedName)
scope.resolveAnyField(qualifiedName)
.ifPresent(field -> referenceFields.put(NodeRef.of(node), field));
return null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,13 @@

package io.wren.base.sqlrewrite.analyzer;

import com.google.common.collect.ImmutableList;
import io.trino.sql.tree.QualifiedName;
import io.trino.sql.tree.WithQuery;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;

Expand Down Expand Up @@ -78,6 +82,29 @@ public Optional<WithQuery> getNamedQuery(String name)
return Optional.empty();
}

/**
* get the columns matching the specified name in the current scope and all parent scopes
*/
public List<Field> resolveFields(QualifiedName name)
{
List<Field> fields = new ArrayList<>(relationType.resolveFields(name));
parent.ifPresent(scope -> fields.addAll(scope.resolveFields(name)));
return ImmutableList.copyOf(fields);
}

/**
* get the columns matching the specified name in the current scope and all parent scopes
*/
public Optional<Field> resolveAnyField(QualifiedName name)
{
return relationType.resolveAnyField(name).or(() -> {
if (parent.isPresent()) {
return parent.get().resolveAnyField(name);
}
return Optional.empty();
});
}

public static Builder builder()
{
return new Builder();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import io.trino.sql.tree.JoinCriteria;
import io.trino.sql.tree.JoinOn;
import io.trino.sql.tree.JoinUsing;
import io.trino.sql.tree.Lateral;
import io.trino.sql.tree.LongLiteral;
import io.trino.sql.tree.NaturalJoin;
import io.trino.sql.tree.Node;
Expand All @@ -37,6 +38,7 @@
import io.trino.sql.tree.QualifiedName;
import io.trino.sql.tree.Query;
import io.trino.sql.tree.QuerySpecification;
import io.trino.sql.tree.Relation;
import io.trino.sql.tree.SelectItem;
import io.trino.sql.tree.SetOperation;
import io.trino.sql.tree.SingleColumn;
Expand Down Expand Up @@ -253,7 +255,7 @@ private List<Field> createScopeForQuery(Query query, QualifiedName scopeName, Op
.or(() -> Optional.ofNullable(QueryUtil.getQualifiedName(singleColumn.getExpression())).map(QualifiedName::getSuffix))
.orElse(singleColumn.getExpression().toString());
if (scope.isPresent()) {
Optional<Field> fieldOptional = scope.get().getRelationType().resolveAnyField(QueryUtil.getQualifiedName(singleColumn.getExpression()));
Optional<Field> fieldOptional = scope.get().resolveAnyField(QueryUtil.getQualifiedName(singleColumn.getExpression()));
if (fieldOptional.isPresent()) {
Field f = fieldOptional.get();
fields.add(Field.builder()
Expand Down Expand Up @@ -456,6 +458,7 @@ protected Scope visitValues(Values node, Optional<Scope> scope)
@Override
protected Scope visitUnnest(Unnest node, Optional<Scope> scope)
{
scope.ifPresent(s -> node.getExpressions().forEach(e -> analyzeExpression(e, s)));
// TODO: output scope here isn't right
return Scope.builder().parent(scope).build();
}
Expand Down Expand Up @@ -530,7 +533,13 @@ protected Scope visitSetOperation(SetOperation node, Optional<Scope> scope)
protected Scope visitJoin(Join node, Optional<Scope> scope)
{
Scope leftScope = process(node.getLeft(), scope);
Scope rightScope = process(node.getRight(), scope);
Scope rightScope;
if (isUnnestOrLateral(node.getRight())) {
rightScope = process(node.getRight(), Optional.of(leftScope));
}
else {
rightScope = process(node.getRight(), scope);
}
RelationType relationType = leftScope.getRelationType().joinWith(rightScope.getRelationType());
Scope outputScope = createAndAssignScope(node, scope, relationType);

Expand All @@ -556,6 +565,22 @@ protected Scope visitJoin(Join node, Optional<Scope> scope)
return createAndAssignScope(node, scope, relationType);
}

protected boolean isUnnestOrLateral(Relation relation)
{
return switch (relation) {
case Unnest ignored -> true;
case Lateral ignored -> true;
case AliasedRelation aliasedRelation -> isUnnestOrLateral(aliasedRelation.getRelation());
default -> false;
};
}

@Override
protected Scope visitLateral(Lateral node, Optional<Scope> scope)
{
return process(node.getQuery(), scope);
}

@Override
protected Scope visitAliasedRelation(AliasedRelation relation, Optional<Scope> scope)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ public ExpressionSourceAnalyzer(Scope scope)
@Override
protected Void visitIdentifier(Identifier node, Void context)
{
scope.getRelationType().resolveFields(QualifiedName.of(node.getValue()))
scope.resolveFields(QualifiedName.of(node.getValue()))
.stream().filter(field -> field.getSourceDatasetName().isPresent())
.forEach(field -> exprSources.add(new ExprSource(node.getValue(), field.getSourceDatasetName().get(), field.getSourceColumn().map(Column::getName).orElse(null), node.getLocation().orElse(null))));
return null;
Expand All @@ -270,7 +270,7 @@ protected Void visitIdentifier(Identifier node, Void context)
protected Void visitDereferenceExpression(DereferenceExpression node, Void context)
{
Optional.ofNullable(getQualifiedName(node)).ifPresent(qualifiedName ->
scope.getRelationType().resolveFields(qualifiedName)
scope.resolveFields(qualifiedName)
.stream().filter(field -> field.getSourceDatasetName().isPresent())
.forEach(field -> exprSources.add(new ExprSource(qualifiedName.toString(), field.getSourceDatasetName().get(), field.getSourceColumn().map(Column::getName).orElse(null), node.getLocation().orElse(null)))));
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,5 +168,41 @@ public void testQueryNestedType()
testDefault = preview(previewDto);
assertThat(testDefault.getColumns().get(0).getType()).isEqualTo("VARCHAR");
assertThat(testDefault.getData().get(0)[0]).isEqualTo("2");

setDuckDBInitSQL("create table nested_table as select * from (values ([1,2,3])) t(a1)");

previewDto = new PreviewDto(manifest, "select * from nested_table n, unnest(n.a1)", null);
testDefault = preview(previewDto);
assertThat(testDefault.getColumns().get(0).getType()).isEqualTo("INTEGER[]");
assertThat(testDefault.getData().get(0).length).isEqualTo(2);
assertThat(testDefault.getData().size()).isEqualTo(3);

previewDto = new PreviewDto(manifest, "select * from nested_table n, unnest(n.a1) u(a1)", null);
testDefault = preview(previewDto);
assertThat(testDefault.getColumns().get(0).getType()).isEqualTo("INTEGER[]");
assertThat(testDefault.getData().get(0).length).isEqualTo(2);
assertThat(testDefault.getData().size()).isEqualTo(3);

previewDto = new PreviewDto(manifest, "select u.a1 from nested_table n, unnest(n.a1) u(a1)", null);
testDefault = preview(previewDto);
assertThat(testDefault.getColumns().get(0).getType()).isEqualTo("INTEGER");
assertThat(testDefault.getData().get(0)[0]).isEqualTo(1);
assertThat(testDefault.getData().get(1)[0]).isEqualTo(2);
assertThat(testDefault.getData().get(2)[0]).isEqualTo(3);

previewDto = new PreviewDto(manifest, "select * from nested_table n cross join lateral (select n.a1[1])", null);
testDefault = preview(previewDto);
assertThat(testDefault.getData().get(0).length).isEqualTo(2);
assertThat(testDefault.getData().size()).isEqualTo(1);

previewDto = new PreviewDto(manifest, "select * from nested_table n cross join lateral (select n.a1[1]) l(a1)", null);
testDefault = preview(previewDto);
assertThat(testDefault.getData().get(0).length).isEqualTo(2);
assertThat(testDefault.getData().size()).isEqualTo(1);

previewDto = new PreviewDto(manifest, "select l.a1 from nested_table n cross join lateral (select n.a1[1]) l(a1)", null);
testDefault = preview(previewDto);
assertThat(testDefault.getColumns().get(0).getType()).isEqualTo("INTEGER");
assertThat(testDefault.getData().get(0)[0]).isEqualTo(1);
}
}

0 comments on commit b221920

Please sign in to comment.