Skip to content

Commit

Permalink
Enhance the analysis API for the location of Join Criteria and source…
Browse files Browse the repository at this point in the history
… columns (#714)

* add node location for join criteria

* add source column for all expr sources
  • Loading branch information
goldmedal authored Jul 25, 2024
1 parent f6c1cbf commit 8449442
Show file tree
Hide file tree
Showing 12 changed files with 307 additions and 85 deletions.
27 changes: 21 additions & 6 deletions ibis-server/tests/routers/v2/test_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def test_analysis_sql_select_all_customer():
"expression": "custkey",
"nodeLocation": {"line": 1, "column": 8},
"sourceDataset": "customer",
"sourceColumn": "custkey",
},
]
assert result[0]["selectItems"][1]["nodeLocation"] == {"line": 1, "column": 8}
Expand Down Expand Up @@ -132,6 +133,7 @@ def test_analysis_sql_group_by_customer():
"expression": "custkey",
"nodeLocation": {"line": 1, "column": 8},
"sourceDataset": "customer",
"sourceColumn": "custkey",
},
],
}
Expand All @@ -154,17 +156,25 @@ def test_analysis_sql_join_customer_orders():
assert result[0]["relation"]["left"]["nodeLocation"] == {"line": 1, "column": 15}
assert result[0]["relation"]["right"]["type"] == "TABLE"
assert result[0]["relation"]["right"]["nodeLocation"] == {"line": 1, "column": 31}
assert result[0]["relation"]["criteria"] == "ON (c.custkey = o.custkey)"
assert (
result[0]["relation"]["criteria"]["expression"] == "ON (c.custkey = o.custkey)"
)
assert result[0]["relation"]["criteria"]["nodeLocation"] == {
"line": 1,
"column": 43,
}
assert result[0]["relation"]["exprSources"] == [
{
"expression": "o.custkey",
"nodeLocation": {"line": 1, "column": 55},
"sourceDataset": "orders",
},
{
"expression": "c.custkey",
"nodeLocation": {"line": 1, "column": 43},
"sourceDataset": "customer",
"sourceColumn": "custkey",
},
{
"expression": "o.custkey",
"nodeLocation": {"line": 1, "column": 55},
"sourceDataset": "orders",
"sourceColumn": "custkey",
},
]

Expand All @@ -188,6 +198,7 @@ def test_analysis_sql_where_clause():
"expression": "custkey",
"nodeLocation": {"line": 1, "column": 30},
"sourceDataset": "customer",
"sourceColumn": "custkey",
},
]
assert result[0]["filter"]["right"]["type"] == "AND"
Expand All @@ -210,6 +221,7 @@ def test_analysis_sql_group_by_multiple_columns():
"expression": "custkey",
"nodeLocation": {"line": 1, "column": 8},
"sourceDataset": "customer",
"sourceColumn": "custkey",
},
],
}
Expand All @@ -221,6 +233,7 @@ def test_analysis_sql_group_by_multiple_columns():
"expression": "name",
"nodeLocation": {"line": 1, "column": 27},
"sourceDataset": "customer",
"sourceColumn": "name",
},
],
}
Expand All @@ -232,6 +245,7 @@ def test_analysis_sql_group_by_multiple_columns():
"expression": "nationkey",
"nodeLocation": {"line": 1, "column": 61},
"sourceDataset": "customer",
"sourceColumn": "nationkey",
},
],
}
Expand All @@ -253,6 +267,7 @@ def test_analysis_sql_order_by():
"expression": "custkey",
"nodeLocation": {"line": 1, "column": 8},
"sourceDataset": "customer",
"sourceColumn": "custkey",
},
]
assert result[0]["sortings"][0]["nodeLocation"] == {"line": 1, "column": 45}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,20 +33,23 @@ public class Field
private final CatalogSchemaTableName tableName;
private final String columnName;
private final Optional<String> sourceDatasetName;
private final Optional<String> sourceColumnName;
private final Optional<String> name;

private Field(
QualifiedName relationAlias,
CatalogSchemaTableName tableName,
String columnName,
String name,
String sourceDatasetName)
String sourceDatasetName,
String sourceColumnName)
{
this.relationAlias = Optional.ofNullable(relationAlias);
this.tableName = requireNonNull(tableName, "modelName is null");
this.columnName = requireNonNull(columnName, "columnName is null");
this.name = Optional.ofNullable(name);
this.sourceDatasetName = Optional.ofNullable(sourceDatasetName);
this.sourceColumnName = Optional.ofNullable(sourceColumnName);
}

public Optional<QualifiedName> getRelationAlias()
Expand Down Expand Up @@ -74,6 +77,11 @@ public Optional<String> getSourceDatasetName()
return sourceDatasetName;
}

/**
 * Returns the source column name recorded for this field.
 *
 * @return the source column name, or an empty {@code Optional} when none was supplied
 */
public Optional<String> getSourceColumnName()
{
return sourceColumnName;
}

public boolean matchesPrefix(Optional<QualifiedName> prefix)
{
return prefix.isEmpty() || relationAlias.orElse(Utils.toQualifiedName(tableName)).hasSuffix(prefix.get());
Expand Down Expand Up @@ -120,6 +128,7 @@ public String toString()
", columnName='" + columnName + '\'' +
", name=" + name +
", sourceDatasetName=" + sourceDatasetName +
", sourceColumnName=" + sourceColumnName +
'}';
}

Expand All @@ -135,6 +144,7 @@ public static class Builder
private String columnName;
private String name;
private String sourceModelName;
private String sourceColumnName;

public Builder() {}

Expand All @@ -145,6 +155,7 @@ public Builder like(Field field)
this.columnName = field.columnName;
this.name = field.name.orElse(null);
this.sourceModelName = field.sourceDatasetName.orElse(null);
this.sourceColumnName = field.sourceColumnName.orElse(null);
return this;
}

Expand Down Expand Up @@ -178,9 +189,15 @@ public Builder sourceModelName(String sourceModelName)
return this;
}

/**
 * Sets the source column name that {@code build()} passes to the new {@code Field}.
 *
 * @param sourceColumnName the source column name; may be {@code null}
 * @return this builder, to allow call chaining
 */
public Builder sourceColumnName(String sourceColumnName)
{
this.sourceColumnName = sourceColumnName;
return this;
}

public Field build()
{
return new Field(relationAlias, tableName, columnName, name, sourceModelName);
return new Field(relationAlias, tableName, columnName, name, sourceModelName, sourceColumnName);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@ private List<Field> createScopeForQuery(Query query, QualifiedName scopeName, Op
.name(f.getName().orElse(f.getColumnName()))
.tableName(toCatalogSchemaTableName(sessionContext, scopeName))
.sourceModelName(f.getSourceDatasetName().orElse(null))
.sourceColumnName(f.getSourceColumnName().orElse(null))
.build())));
}
else {
Expand All @@ -253,6 +254,7 @@ private List<Field> createScopeForQuery(Query query, QualifiedName scopeName, Op
.name(name)
.tableName(toCatalogSchemaTableName(sessionContext, scopeName))
.sourceModelName(f.getSourceDatasetName().orElse(null))
.sourceColumnName(f.getSourceColumnName().orElse(null))
.build());
continue;
}
Expand Down Expand Up @@ -280,6 +282,7 @@ private List<Field> collectFieldFromMDL(CatalogSchemaTableName tableName)
.columnName(column.getName())
.name(column.getName())
.sourceModelName(tableName.getSchemaTableName().getTableName())
.sourceColumnName(column.getName())
.build())
.collect(toImmutableList());
}
Expand All @@ -293,6 +296,7 @@ else if (wrenMDL.getMetric(tableName.getSchemaTableName().getTableName()).isPres
.columnName(column.getName())
.name(column.getName())
.sourceModelName(tableName.getSchemaTableName().getTableName())
.sourceColumnName(column.getName())
.build())
.collect(toImmutableList());
}
Expand All @@ -304,12 +308,14 @@ else if (wrenMDL.getCumulativeMetric(tableName.getSchemaTableName().getTableName
.columnName(cumulativeMetric.getWindow().getName())
.name(cumulativeMetric.getWindow().getName())
.sourceModelName(tableName.getSchemaTableName().getTableName())
.sourceColumnName(cumulativeMetric.getWindow().getName())
.build(),
Field.builder()
.tableName(tableName)
.columnName(cumulativeMetric.getMeasure().getName())
.name(cumulativeMetric.getMeasure().getName())
.sourceModelName(tableName.getSchemaTableName().getTableName())
.sourceColumnName(cumulativeMetric.getMeasure().getName())
.build());
}
return ImmutableList.of();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ protected Void visitAllColumns(AllColumns node, DecisionPointContext decisionPoi
List.of(new ExprSource(
field.getName().orElse(field.getColumnName()),
field.getTableName().getSchemaTableName().getTableName(),
field.getSourceColumnName().orElse(null),
node.getLocation().orElse(null)))));
});
}
Expand All @@ -164,6 +165,7 @@ protected Void visitAllColumns(AllColumns node, DecisionPointContext decisionPoi
List.of(new ExprSource(
field.getName().orElse(field.getColumnName()),
field.getTableName().getSchemaTableName().getTableName(),
field.getSourceColumnName().orElse(null),
node.getLocation().orElse(null)))));
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import java.util.Objects;

public record ExprSource(String expression, String sourceDataset, NodeLocation nodeLocation)
public record ExprSource(String expression, String sourceDataset, String sourceColumn, NodeLocation nodeLocation)
{
@Override
public boolean equals(Object o)
Expand All @@ -34,12 +34,13 @@ public boolean equals(Object o)
ExprSource that = (ExprSource) o;
return Objects.equals(expression, that.expression) &&
Objects.equals(sourceDataset, that.sourceDataset)
&& Objects.equals(sourceColumn, that.sourceColumn)
&& Objects.equals(nodeLocation, that.nodeLocation);
}

@Override
public int hashCode()
{
return Objects.hash(expression, sourceDataset, nodeLocation);
return Objects.hash(expression, sourceDataset, sourceColumn, nodeLocation);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.wren.base.sqlrewrite.analyzer.decisionpoint;

import io.trino.sql.tree.ArithmeticBinaryExpression;
import io.trino.sql.tree.ComparisonExpression;
import io.trino.sql.tree.DefaultTraversalVisitor;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.Node;
import io.trino.sql.tree.NodeLocation;

import java.util.Optional;

/**
 * Finds the leftmost location of an expression.
 * <p>
 * For the arithmetic expression {@code a + b} the location of {@code a} is returned,
 * and for the comparison {@code a = b} the location of {@code a} is returned.
 * Left operands that are themselves arithmetic or comparison expressions are followed
 * recursively, so {@code a + b = c} and {@code a + b + c} also resolve to the location
 * of {@code a}.
 * <p>
 * Any other expression that reaches {@code visitExpression} reports its own location.
 */
public class ExpressionLocationAnalyzer
{
    private ExpressionLocationAnalyzer() {}

    /**
     * Analyzes the given node and returns the leftmost location found.
     *
     * @param node the AST node to inspect
     * @return the resolved location, or an empty {@code Optional} when the visited node
     * carries no location or no expression was visited
     */
    public static Optional<NodeLocation> analyze(Node node)
    {
        Visitor visitor = new Visitor();
        visitor.process(node, null);
        return visitor.nodeLocation;
    }

    /**
     * Visitor that records the location selected for the processed node.
     */
    static class Visitor
            extends DefaultTraversalVisitor<Void>
    {
        private Optional<NodeLocation> nodeLocation = Optional.empty();

        @Override
        protected Void visitExpression(Expression node, Void context)
        {
            // NOTE(review): node kinds for which DefaultTraversalVisitor defines its own
            // traversal may not be dispatched here directly - confirm against the Trino version in use.
            nodeLocation = node.getLocation();
            return null;
        }

        @Override
        protected Void visitComparisonExpression(ComparisonExpression node, Void context)
        {
            nodeLocation = leftmostLocation(node.getLeft());
            return null;
        }

        @Override
        protected Void visitArithmeticBinary(ArithmeticBinaryExpression node, Void context)
        {
            nodeLocation = leftmostLocation(node.getLeft());
            return null;
        }

        /**
         * Walks down the chain of left operands until a node that is neither a comparison
         * nor an arithmetic binary expression is reached, and returns that node's location.
         * <p>
         * Taking only the direct left operand is not enough for nested input: the location
         * stored on a binary node is not necessarily where its left operand starts
         * (NOTE(review): the Trino parser appears to record the operator token - confirm).
         */
        private static Optional<NodeLocation> leftmostLocation(Expression expression)
        {
            Expression current = expression;
            while (true) {
                if (current instanceof ComparisonExpression) {
                    current = ((ComparisonExpression) current).getLeft();
                }
                else if (current instanceof ArithmeticBinaryExpression) {
                    current = ((ArithmeticBinaryExpression) current).getLeft();
                }
                else {
                    return current.getLocation();
                }
            }
        }
    }
}
Loading

0 comments on commit 8449442

Please sign in to comment.