From 039fce0a5dfe2bc7d2fcc44933b2de1af3467a85 Mon Sep 17 00:00:00 2001 From: Xiaoli Zhou Date: Thu, 9 May 2024 11:35:05 +0800 Subject: [PATCH] fix(interactive): Fix Bugs of `Where Subquery` in Gremlin (#3757) ## What do these changes do? fix bugs of `where subquery` in gremlin: ``` g.V().as("a").out().as("b").select("a", "b").where(__.as("a").out("knows").as("b")) g.V().as("a").out().in().as("b").select("a", "b").where(__.as("b").has("name", "marko")) ``` ## Related issue number Fixes --- .../compiler/conf/ir.compiler.properties | 4 +- .../antlr4x/visitor/GraphBuilderVisitor.java | 44 +++++++++- .../visitor/NestedTraversalRexVisitor.java | 68 +++++++++++---- .../visitor/TraversalMethodIterator.java | 5 -- .../suite/standard/IrGremlinQueryTest.java | 82 +++++++++++++++++++ .../gremlin/antlr4x/GraphBuilderTest.java | 56 +++++++++++++ 6 files changed, 237 insertions(+), 22 deletions(-) diff --git a/interactive_engine/compiler/conf/ir.compiler.properties b/interactive_engine/compiler/conf/ir.compiler.properties index f1da906c4b2e..7f18a0271721 100644 --- a/interactive_engine/compiler/conf/ir.compiler.properties +++ b/interactive_engine/compiler/conf/ir.compiler.properties @@ -22,7 +22,7 @@ graph.planner.opt: RBO graph.planner.rules: FilterIntoJoinRule, FilterMatchRule, ExtendIntersectRule, ExpandGetVFusionRule # set file path of glogue input statistics -# graph.planner.cbo.glogue.schema: +# graph.planner.cbo.glogue.schema: src/test/resources/statistics/modern_statistics.txt # set stored procedures directory path # graph.stored.procedures: @@ -60,4 +60,4 @@ calcite.default.charset: UTF-8 # physical.opt.config: ffi # set the max capacity of the result streaming buffer for each query -# per.query.stream.buffer.max.capacity: 256 \ No newline at end of file +# per.query.stream.buffer.max.capacity: 256 diff --git a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/gremlin/antlr4x/visitor/GraphBuilderVisitor.java b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/gremlin/antlr4x/visitor/GraphBuilderVisitor.java index 10aa08c94f72..816513273e60 100644 --- a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/gremlin/antlr4x/visitor/GraphBuilderVisitor.java +++ b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/gremlin/antlr4x/visitor/GraphBuilderVisitor.java @@ -41,6 +41,8 @@ import com.google.common.collect.Lists; import com.google.common.collect.Maps; +import org.antlr.v4.runtime.tree.ParseTree; +import org.antlr.v4.runtime.tree.RuleNode; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.core.Aggregate; import org.apache.calcite.rel.core.Project; @@ -58,20 +60,48 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.function.Predicate; import java.util.stream.Collectors; import java.util.stream.Stream; public class GraphBuilderVisitor extends GremlinGSBaseVisitor { private final GraphBuilder builder; private final ExprUniqueAliasInfer aliasInfer; + private final Predicate trimAlias; public GraphBuilderVisitor(GraphBuilder builder) { this(builder, new ExprUniqueAliasInfer()); } public GraphBuilderVisitor(GraphBuilder builder, ExprUniqueAliasInfer aliasInfer) { + this(builder, aliasInfer, t -> false); + } + + public GraphBuilderVisitor(GraphBuilder builder, Predicate trimAlias) { + this(builder, new ExprUniqueAliasInfer(), trimAlias); + } + + public GraphBuilderVisitor( + GraphBuilder builder, ExprUniqueAliasInfer aliasInfer, Predicate trimAlias) { this.builder = Objects.requireNonNull(builder); this.aliasInfer = Objects.requireNonNull(aliasInfer); + this.trimAlias = Objects.requireNonNull(trimAlias); + } + + // re-implement the function in that we need to customize the behavior to visit each parse tree + // node + @Override + public GraphBuilder visitChildren(RuleNode node) { + GraphBuilder result = this.defaultResult(); + int n = node.getChildCount(); + + for (int i = 0; i < n && this.shouldVisitNextChild(node, result); ++i) { + ParseTree c = node.getChild(i); + GraphBuilder childResult = (!trimAlias.test(c)) ? c.accept(this) : this.builder; + result = this.aggregateResult(result, childResult); + } + + return result; } @Override @@ -620,8 +650,20 @@ public GraphBuilder visitTraversalMethod_where( } return builder.filter(exprRes.getExpr()); } else if (ctx.nestedTraversal() != null) { + TraversalMethodIterator methodIterator = + new TraversalMethodIterator(ctx.nestedTraversal()); + String alias = null; + if (methodIterator.hasNext()) { + GremlinGSParser.TraversalMethodContext methodCtx = methodIterator.next(); + if (methodCtx.traversalMethod_as() != null) { + alias = + (String) + LiteralVisitor.INSTANCE.visit( + methodCtx.traversalMethod_as().StringLiteral()); + } + } RexNode subQuery = - (new NestedTraversalRexVisitor(builder, null, ctx)) + (new NestedTraversalRexVisitor(builder, alias, ctx)) .visitNestedTraversal(ctx.nestedTraversal()); return builder.filter(Utils.convertExprToPair(subQuery).getValue0()); } else if (ctx.traversalMethod_not() != null) { diff --git a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/gremlin/antlr4x/visitor/NestedTraversalRexVisitor.java b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/gremlin/antlr4x/visitor/NestedTraversalRexVisitor.java index 283d957bdd4b..680269c8e07b 100644 --- a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/gremlin/antlr4x/visitor/NestedTraversalRexVisitor.java +++ b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/gremlin/antlr4x/visitor/NestedTraversalRexVisitor.java @@ -25,6 +25,7 @@ import com.google.common.collect.ImmutableList; import org.antlr.v4.runtime.ParserRuleContext; +import org.antlr.v4.runtime.tree.ParseTree; import org.apache.calcite.plan.GraphOptCluster; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.core.Project; @@ -33,22 +34,23 @@ import org.checkerframework.checker.nullness.qual.Nullable; import java.util.Objects; +import java.util.function.Predicate; public class NestedTraversalRexVisitor extends GremlinGSBaseVisitor { private final GraphBuilder parentBuilder; private final GraphBuilder nestedBuilder; - private final @Nullable String tag; + private final @Nullable String headAlias; private final ParserRuleContext parentCtx; public NestedTraversalRexVisitor( - GraphBuilder parentBuilder, @Nullable String tag, ParserRuleContext parentCtx) { + GraphBuilder parentBuilder, @Nullable String headAlias, ParserRuleContext parentCtx) { this.parentBuilder = parentBuilder; this.nestedBuilder = GraphBuilder.create( this.parentBuilder.getContext(), (GraphOptCluster) this.parentBuilder.getCluster(), this.parentBuilder.getRelOptSchema()); - this.tag = tag; + this.headAlias = headAlias; this.parentCtx = parentCtx; } @@ -63,30 +65,45 @@ public RexNode visitNestedTraversal(GremlinGSParser.NestedTraversalContext ctx) commonRel.getTraitSet(), new CommonOptTable(commonRel)); nestedBuilder.push(commonRel); - if (tag != null) { + if (headAlias != null) { nestedBuilder.project( - ImmutableList.of(nestedBuilder.variable(tag)), ImmutableList.of(), true); + ImmutableList.of(nestedBuilder.variable(headAlias)), ImmutableList.of(), true); } - GraphBuilderVisitor visitor = new GraphBuilderVisitor(nestedBuilder); - RelNode subRel = visitor.visitNestedTraversal(ctx).build(); - String alias = null; // set query given alias + String tailAlias = null; + int methodCounter = 0; TraversalMethodIterator iterator = new TraversalMethodIterator(ctx); while (iterator.hasNext()) { GremlinGSParser.TraversalMethodContext cur = iterator.next(); - if (cur.traversalMethod_as() != null) { - alias = + if (methodCounter != 0 && !iterator.hasNext() && cur.traversalMethod_as() != null) { + tailAlias = (String) LiteralVisitor.INSTANCE.visit( cur.traversalMethod_as().StringLiteral()); } + ++methodCounter; } - if (alias != null) { - subRel = nestedBuilder.push(subRel).as(null).build(); - } + TrimAlias trimAlias = new TrimAlias(methodCounter); + GraphBuilderVisitor visitor = new GraphBuilderVisitor(nestedBuilder, trimAlias); + // skip head and tail aliases in nested traversal, we have handled them specifically in + // current context. + RelNode subRel = visitor.visitNestedTraversal(ctx).build(); RexNode expr; if (new SubQueryChecker(commonRel).test(subRel)) { if (parentCtx instanceof GremlinGSParser.TraversalMethod_whereContext) { + if (tailAlias != null) { + // specific implementation for `where(('a').out().as('b')`, convert tail alias + // `as('b')` to `where(eq('b'))` + subRel = + nestedBuilder + .push(subRel) + .filter( + nestedBuilder.equals( + nestedBuilder.variable((String) null), + nestedBuilder.variable(tailAlias))) + .build(); + tailAlias = null; + } expr = RexSubQuery.exists(subRel); } else if (parentCtx instanceof GremlinGSParser.TraversalMethod_notContext) { expr = nestedBuilder.not(RexSubQuery.exists(subRel)); // convert to not exist @@ -115,6 +132,29 @@ public RexNode visitNestedTraversal(GremlinGSParser.NestedTraversalContext ctx) expr = nestedBuilder.isNull(expr); } } - return nestedBuilder.alias(expr, alias); + return nestedBuilder.alias(expr, tailAlias); + } + + private class TrimAlias implements Predicate { + private int methodIdx; + private final int methodCount; + + public TrimAlias(int methodCount) { + this.methodCount = methodCount; + this.methodIdx = 0; + } + + @Override + public boolean test(ParseTree parseTree) { + if (parseTree.getParent() instanceof GremlinGSParser.TraversalMethodContext) { + boolean toSkip = + parseTree instanceof GremlinGSParser.TraversalMethod_asContext + && (methodIdx == 0 || methodIdx == methodCount - 1); + ++methodIdx; + return toSkip; + } else { + return false; + } + } } } diff --git a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/gremlin/antlr4x/visitor/TraversalMethodIterator.java b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/gremlin/antlr4x/visitor/TraversalMethodIterator.java index 2368a1d741dd..eba3c21058b3 100644 --- a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/gremlin/antlr4x/visitor/TraversalMethodIterator.java +++ b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/gremlin/antlr4x/visitor/TraversalMethodIterator.java @@ -79,11 +79,6 @@ public GremlinGSParser.TraversalMethodContext next() { private @Nullable ParseTree getParent(ParseTree child) { Class parentClass = GremlinGSParser.ChainedTraversalContext.class; - while (child != null - && child.getParent() != null - && !child.getParent().getClass().equals(parentClass)) { - child = child.getParent(); - } return (child != null && child.getParent() != null && child.getParent().getClass().equals(parentClass)) diff --git a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/gremlin/integration/suite/standard/IrGremlinQueryTest.java b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/gremlin/integration/suite/standard/IrGremlinQueryTest.java index 57a98e87add5..c82bea1a7b10 100644 --- a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/gremlin/integration/suite/standard/IrGremlinQueryTest.java +++ b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/gremlin/integration/suite/standard/IrGremlinQueryTest.java @@ -16,7 +16,9 @@ package com.alibaba.graphscope.gremlin.integration.suite.standard; +import static org.apache.tinkerpop.gremlin.LoadGraphWith.GraphData.MODERN; import static org.apache.tinkerpop.gremlin.process.traversal.dsl.graph.__.*; +import static org.junit.Assert.*; import static org.junit.Assume.assumeFalse; import static org.junit.Assume.assumeTrue; @@ -229,6 +231,60 @@ public void g_V_where_expr_name_equal_marko_and_age_gt_20_or_age_lt_10_name() { public abstract Traversal get_g_V_where_not_values_age_count(); + public abstract Traversal> + get_g_V_hasXageX_asXaX_out_in_hasXageX_asXbX_selectXa_bX_whereXa_outXknowsX_bX(); + + public abstract Traversal> + get_g_V_hasXageX_asXaX_out_in_hasXageX_asXbX_selectXa_bX_whereXb_hasXname_markoXX(); + + @Test + @LoadGraphWith(MODERN) + public void g_V_hasXageX_asXaX_out_in_hasXageX_asXbX_selectXa_bX_whereXa_outXknowsX_bX() { + assumeFalse("hiactor".equals(System.getenv("ENGINE_TYPE"))); + final Traversal> traversal = + get_g_V_hasXageX_asXaX_out_in_hasXageX_asXbX_selectXa_bX_whereXa_outXknowsX_bX(); + printTraversalForm(traversal); + int counter = 0; + while (traversal.hasNext()) { + counter++; + final Map map = traversal.next(); + assertEquals(2, map.size()); + assertTrue(map.containsKey("a")); + assertTrue(map.containsKey("b")); + assertEquals(convertToVertexId("marko"), ((Vertex) map.get("a")).id()); + assertEquals(convertToVertexId("josh"), ((Vertex) map.get("b")).id()); + } + assertEquals(1, counter); + assertFalse(traversal.hasNext()); + } + + @Test + @LoadGraphWith(MODERN) + public void g_V_hasXageX_asXaX_out_in_hasXageX_asXbX_selectXa_bX_whereXb_hasXname_markoXX() { + assumeFalse("hiactor".equals(System.getenv("ENGINE_TYPE"))); + final Traversal> traversal = + get_g_V_hasXageX_asXaX_out_in_hasXageX_asXbX_selectXa_bX_whereXb_hasXname_markoXX(); + printTraversalForm(traversal); + int counter = 0; + int markoCounter = 0; + while (traversal.hasNext()) { + counter++; + final Map map = traversal.next(); + assertEquals(2, map.size()); + assertTrue(map.containsKey("a")); + assertTrue(map.containsKey("b")); + assertEquals(convertToVertexId("marko"), ((Vertex) map.get("b")).id()); + if (((Vertex) map.get("a")).id().equals(convertToVertexId("marko"))) markoCounter++; + else + assertTrue( + ((Vertex) map.get("a")).id().equals(convertToVertexId("josh")) + || ((Vertex) map.get("a")).id().equals(convertToVertexId("peter"))); + } + assertEquals(3, markoCounter); + assertEquals(5, counter); + assertFalse(traversal.hasNext()); + } + @LoadGraphWith(LoadGraphWith.GraphData.MODERN) @Test public void g_V_where_out_out_count() { @@ -1463,5 +1519,31 @@ public Traversal get_g_VX1X_outE_asXhereX_inV_hasXname_vadasX_sele public Traversal get_g_E_hasLabelXknowsX() { return g.E().hasLabel("knows"); } + + @Override + public Traversal> + get_g_V_hasXageX_asXaX_out_in_hasXageX_asXbX_selectXa_bX_whereXa_outXknowsX_bX() { + return g.V().has("age") + .as("a") + .out() + .in() + .has("age") + .as("b") + .select("a", "b") + .where(as("a").out("knows").as("b")); + } + + @Override + public Traversal> + get_g_V_hasXageX_asXaX_out_in_hasXageX_asXbX_selectXa_bX_whereXb_hasXname_markoXX() { + return g.V().has("age") + .as("a") + .out() + .in() + .has("age") + .as("b") + .select("a", "b") + .where(as("b").has("name", "marko")); + } } } diff --git a/interactive_engine/compiler/src/test/java/com/alibaba/graphscope/gremlin/antlr4x/GraphBuilderTest.java b/interactive_engine/compiler/src/test/java/com/alibaba/graphscope/gremlin/antlr4x/GraphBuilderTest.java index d2a2219b49ba..0884fb238f56 100644 --- a/interactive_engine/compiler/src/test/java/com/alibaba/graphscope/gremlin/antlr4x/GraphBuilderTest.java +++ b/interactive_engine/compiler/src/test/java/com/alibaba/graphscope/gremlin/antlr4x/GraphBuilderTest.java @@ -1179,6 +1179,62 @@ public void g_V_where_not_out_out_test() { node.explain().trim()); } + @Test + public void g_V_where_as_a_out_as_b_select_a_b_where_as_a_out_as_b_test() { + RelNode node = + eval( + "g.V().as(\"a\").out().as(\"b\").select(\"a\"," + + " \"b\").where(as('a').out('knows').as('b'))"); + RelOptPlanner planner = + Utils.mockPlanner(ExpandGetVFusionRule.BasicExpandGetVFusionRule.Config.DEFAULT); + planner.setRoot(node); + node = planner.findBestExp(); + Assert.assertEquals( + "GraphLogicalProject($f0=[$f0], isAppend=[false])\n" + + " LogicalFilter(condition=[EXISTS({\n" + + "LogicalFilter(condition=[=(_, b)])\n" + + " GraphPhysicalExpand(tableConfig=[{isAll=false, tables=[knows]}]," + + " alias=[_], opt=[OUT], physicalOpt=[VERTEX])\n" + + " GraphLogicalProject(_=[a], isAppend=[true])\n" + + " CommonTableScan(table=[[common#-630412335]])\n" + + "})])\n" + + " GraphLogicalProject($f0=[MAP(_UTF-8'a', a, _UTF-8'b', b)]," + + " isAppend=[true])\n" + + " GraphPhysicalExpand(tableConfig=[{isAll=true, tables=[created," + + " knows]}], alias=[b], opt=[OUT], physicalOpt=[VERTEX])\n" + + " GraphLogicalSource(tableConfig=[{isAll=true, tables=[software," + + " person]}], alias=[a], opt=[VERTEX])", + node.explain().trim()); + } + + @Test + public void g_V_where_as_a_out_as_b_select_a_b_where_as_a_has_name_marko_test() { + RelNode node = + eval( + "g.V().as(\"a\").out().in().as(\"b\").select(\"a\"," + + " \"b\").where(__.as(\"b\").has(\"name\", \"marko\"))"); + RelOptPlanner planner = + Utils.mockPlanner(ExpandGetVFusionRule.BasicExpandGetVFusionRule.Config.DEFAULT); + planner.setRoot(node); + node = planner.findBestExp(); + Assert.assertEquals( + "GraphLogicalProject($f0=[$f0], isAppend=[false])\n" + + " LogicalFilter(condition=[EXISTS({\n" + + "LogicalFilter(condition=[=(_.name, _UTF-8'marko')])\n" + + " GraphLogicalProject(_=[b], isAppend=[true])\n" + + " CommonTableScan(table=[[common#1582012392]])\n" + + "})])\n" + + " GraphLogicalProject($f0=[MAP(_UTF-8'a', a, _UTF-8'b', b)]," + + " isAppend=[true])\n" + + " GraphPhysicalExpand(tableConfig=[{isAll=true, tables=[created," + + " knows]}], alias=[b], opt=[IN], physicalOpt=[VERTEX])\n" + + " GraphPhysicalExpand(tableConfig=[{isAll=true, tables=[created," + + " knows]}], alias=[_], opt=[OUT], physicalOpt=[VERTEX])\n" + + " GraphLogicalSource(tableConfig=[{isAll=true, tables=[software," + + " person]}], alias=[a], opt=[VERTEX])", + node.explain().trim()); + } + @Test public void g_V_where_a_neq_b_by_out_count_test() { RelNode node =