Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(interactive): Fix Bugs in Dynamic Params #3279

Merged
merged 4 commits into from
Oct 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.type.BasicSqlType;
import org.apache.calcite.sql.type.GraphInferTypes;
import org.apache.calcite.sql.type.IntervalSqlType;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.tools.RelBuilder;
Expand Down Expand Up @@ -572,6 +573,11 @@
throw new UnsupportedOperationException(
"operator " + operator.getKind().name() + " not supported");
}
// infer unknown operands types from other known types
if (operator.getOperandTypeInference() != GraphInferTypes.RETURN_TYPE) {
operandList =
inferOperandTypes(operator, getTypeFactory().createUnknownType(), operandList);
}
RexCallBinding callBinding =
new RexCallBinding(getTypeFactory(), operator, operandList, ImmutableList.of());
// check count of operands, if fail throw exceptions
Expand Down Expand Up @@ -955,102 +961,102 @@
* @param nodes build limit() if empty
* @return
*/
@Override
public GraphBuilder sortLimit(
@Nullable RexNode offsetNode,
@Nullable RexNode fetchNode,
Iterable<? extends RexNode> nodes) {
if (offsetNode != null && !(offsetNode instanceof RexLiteral)) {
throw new IllegalArgumentException("OFFSET node must be RexLiteral");
}
if (offsetNode != null && !(offsetNode instanceof RexLiteral)) {
throw new IllegalArgumentException("FETCH node must be RexLiteral");
}

RelNode input = requireNonNull(peek(), "frame stack is empty");

List<RelDataTypeField> originalFields = input.getRowType().getFieldList();

Registrar registrar = new Registrar(this, input, true);
List<RexNode> registerNodes = registrar.registerExpressions(ImmutableList.copyOf(nodes));

// expressions need to be projected in advance
if (!registrar.getExtraNodes().isEmpty()) {
project(registrar.getExtraNodes(), registrar.getExtraAliases(), registrar.isAppend());
RexTmpVariableConverter converter = new RexTmpVariableConverter(true, this);
registerNodes =
registerNodes.stream()
.map(k -> k.accept(converter))
.collect(Collectors.toList());
input = requireNonNull(peek(), "frame stack is empty");
}

List<RelFieldCollation> fieldCollations = fieldCollations(registerNodes);
Config config = Utils.getFieldValue(RelBuilder.class, this, "config");

// limit 0 -> return empty value
if ((fetchNode != null && RexLiteral.intValue(fetchNode) == 0) && config.simplifyLimit()) {
return (GraphBuilder) empty();
}

// output all results without any order -> skip
if (offsetNode == null && fetchNode == null && fieldCollations.isEmpty()) {
return this; // sort is trivial
}
// sortLimit is actually limit if collations are empty
if (fieldCollations.isEmpty()) {
// fuse limit with the previous sort operator
// order + limit -> topK
if (input instanceof Sort) {
Sort sort2 = (Sort) input;
// output all results without any limitations
if (sort2.offset == null && sort2.fetch == null) {
RelNode sort =
GraphLogicalSort.create(
sort2.getInput(), sort2.collation, offsetNode, fetchNode);
replaceTop(sort);
return this;
}
}
// order + project + limit -> topK + project
if (input instanceof Project) {
Project project = (Project) input;
if (project.getInput() instanceof Sort) {
Sort sort2 = (Sort) project.getInput();
if (sort2.offset == null && sort2.fetch == null) {
RelNode sort =
GraphLogicalSort.create(
sort2.getInput(), sort2.collation, offsetNode, fetchNode);
replaceTop(
GraphLogicalProject.create(
(GraphOptCluster) project.getCluster(),
project.getHints(),
sort,
project.getProjects(),
project.getRowType(),
((GraphLogicalProject) project).isAppend()));
return this;
}
}
}
}
RelNode sort =
GraphLogicalSort.create(
input, GraphRelCollations.of(fieldCollations), offsetNode, fetchNode);
replaceTop(sort);
// to remove the extra columns we have added
if (!registrar.getExtraAliases().isEmpty()) {
List<RexNode> originalExprs = new ArrayList<>();
List<String> originalAliases = new ArrayList<>();
for (RelDataTypeField field : originalFields) {
originalExprs.add(variable(field.getName()));
originalAliases.add(field.getName());
}
project(originalExprs, originalAliases, false);
}
return this;
}

Check notice on line 1059 in interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/tools/GraphBuilder.java

View check run for this annotation

codefactor.io / CodeFactor

interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/tools/GraphBuilder.java#L964-L1059

Complex Method
@Override
public RelBuilder join(
JoinRelType joinType, RexNode condition, Set<CorrelationId> variablesSet) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -541,7 +541,7 @@ public int generate(@Nullable String paramName) {
}

public ImmutableMap<Integer, String> getDynamicParams() {
return this.paramsBuilder.build();
return this.paramsBuilder.buildKeepingLast();
}

private RexLiteral createIntervalLiteral(String fieldName) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package com.alibaba.graphscope.common.ir;

import com.alibaba.graphscope.common.ir.tools.GraphBuilder;
import com.alibaba.graphscope.common.ir.tools.GraphRexBuilder;
import com.alibaba.graphscope.common.ir.tools.GraphStdOperatorTable;
import com.alibaba.graphscope.common.ir.tools.config.GraphOpt;
import com.alibaba.graphscope.common.ir.tools.config.LabelConfig;
Expand Down Expand Up @@ -169,6 +170,18 @@ public void unary_minus_test() {
Assert.assertEquals("-(a.age)", node.toString());
}

// @age + 1
@Test
public void dynamic_param_type_test() {
GraphRexBuilder rexBuilder = (GraphRexBuilder) builder.getRexBuilder();
RexNode plus =
builder.call(
GraphStdOperatorTable.PLUS,
rexBuilder.makeGraphDynamicParam("age", 0),
builder.literal(1));
Assert.assertEquals(SqlTypeName.INTEGER, plus.getType().getSqlTypeName());
}

private SourceConfig mockSourceConfig(String alias) {
return new SourceConfig(
GraphOpt.Source.VERTEX, new LabelConfig(false).addLabel("person"), alias);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -291,4 +291,24 @@ public void match_15_test() {
+ "], matchOpt=[INNER])",
node.explain().trim());
}

@Test
public void match_16_test() {
RelNode node =
Utils.eval(
"Match (a:person {name: $name})-[b]->(c:person {name: $name})"
+ " Return a, c")
.build();
Assert.assertEquals(
"GraphLogicalProject(a=[a], c=[c], isAppend=[false])\n"
+ " GraphLogicalSingleMatch(input=[null],"
+ " sentence=[GraphLogicalGetV(tableConfig=[{isAll=false, tables=[person]}],"
+ " alias=[c], fusedFilter=[[=(DEFAULT.name, ?0)]], opt=[END])\n"
+ " GraphLogicalExpand(tableConfig=[{isAll=true, tables=[created, knows]}],"
+ " alias=[b], opt=[OUT])\n"
+ " GraphLogicalSource(tableConfig=[{isAll=false, tables=[person]}],"
+ " alias=[a], fusedFilter=[[=(DEFAULT.name, ?0)]], opt=[VERTEX])\n"
+ "], matchOpt=[INNER])",
node.explain().trim());
}
}
Loading