diff --git a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/runtime/ffi/RelToFfiConverter.java b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/runtime/ffi/RelToFfiConverter.java index fb044d9e4402..02332052ab5f 100644 --- a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/runtime/ffi/RelToFfiConverter.java +++ b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/runtime/ffi/RelToFfiConverter.java @@ -206,7 +206,9 @@ public RelNode visit(GraphLogicalMultiMatch match) { @Override public RelNode visit(LogicalFilter logicalFilter) { OuterExpression.Expression exprProto = - logicalFilter.getCondition().accept(new RexToProtoConverter(true, isColumnId)); + logicalFilter + .getCondition() + .accept(new RexToProtoConverter(true, isColumnId, this.rexBuilder)); Pointer ptrFilter = LIB.initSelectOperator(); checkFfiResult( LIB.setSelectPredicatePb( @@ -220,7 +222,9 @@ public PhysicalNode visit(GraphLogicalProject project) { List fields = project.getRowType().getFieldList(); for (int i = 0; i < project.getProjects().size(); ++i) { OuterExpression.Expression expression = - project.getProjects().get(i).accept(new RexToProtoConverter(true, isColumnId)); + project.getProjects() + .get(i) + .accept(new RexToProtoConverter(true, isColumnId, this.rexBuilder)); int aliasId = fields.get(i).getIndex(); FfiAlias.ByValue ffiAlias = (aliasId == AliasInference.DEFAULT_ID) @@ -261,7 +265,7 @@ public PhysicalNode visit(GraphLogicalAggregate aggregate) { RexGraphVariable.class, var.getClass()); OuterExpression.Expression expr = - var.accept(new RexToProtoConverter(true, isColumnId)); + var.accept(new RexToProtoConverter(true, isColumnId, this.rexBuilder)); int aliasId; if (i >= fields.size() || (aliasId = fields.get(i).getIndex()) == AliasInference.DEFAULT_ID) { @@ -284,7 +288,7 @@ public PhysicalNode visit(GraphLogicalAggregate aggregate) { field.getName(), field.getType()); OuterExpression.Variable exprVar = - rexVar.accept(new RexToProtoConverter(true, isColumnId)) + rexVar.accept(new RexToProtoConverter(true, isColumnId, this.rexBuilder)) .getOperators(0) .getVar(); checkFfiResult( @@ -306,7 +310,7 @@ public PhysicalNode visit(GraphLogicalAggregate aggregate) { OuterExpression.Variable var = groupKeys .get(i) - .accept(new RexToProtoConverter(true, isColumnId)) + .accept(new RexToProtoConverter(true, isColumnId, this.rexBuilder)) .getOperators(0) .getVar(); int aliasId = fields.get(i).getIndex(); @@ -340,7 +344,7 @@ public PhysicalNode visit(GraphLogicalAggregate aggregate) { operands.get(0).getClass()); OuterExpression.Variable var = operands.get(0) - .accept(new RexToProtoConverter(true, isColumnId)) + .accept(new RexToProtoConverter(true, isColumnId, this.rexBuilder)) .getOperators(0) .getVar(); checkFfiResult( @@ -370,7 +374,7 @@ public PhysicalNode visit(GraphLogicalSort sort) { for (int i = 0; i < collations.size(); ++i) { RexGraphVariable expr = ((GraphFieldCollation) collations.get(i)).getVariable(); OuterExpression.Variable var = - expr.accept(new RexToProtoConverter(true, isColumnId)) + expr.accept(new RexToProtoConverter(true, isColumnId, this.rexBuilder)) .getOperators(0) .getVar(); checkFfiResult( @@ -406,13 +410,13 @@ public PhysicalNode visit(LogicalJoin join) { OuterExpression.Variable leftVar = leftRightVars .get(0) - .accept(new RexToProtoConverter(true, isColumnId)) + .accept(new RexToProtoConverter(true, isColumnId, this.rexBuilder)) .getOperators(0) .getVar(); OuterExpression.Variable rightVar = leftRightVars .get(1) - .accept(new RexToProtoConverter(true, isColumnId)) + .accept(new RexToProtoConverter(true, isColumnId, this.rexBuilder)) .getOperators(0) .getVar(); checkFfiResult( @@ -452,7 +456,10 @@ private Pointer ffiQueryParams(AbstractBindableTableScan tableScan) { }); if (ObjectUtils.isNotEmpty(tableScan.getFilters())) { OuterExpression.Expression expression = - tableScan.getFilters().get(0).accept(new RexToProtoConverter(true, isColumnId)); + tableScan + .getFilters() + .get(0) + .accept(new RexToProtoConverter(true, isColumnId, this.rexBuilder)); checkFfiResult( LIB.setParamsPredicatePb( params, new FfiPbPointer.ByValue(expression.toByteArray()))); @@ -632,7 +639,8 @@ private void addFilterToFfiBinder(Pointer ptrSentence, AbstractBindableTableScan List filters = tableScan.getFilters(); if (ObjectUtils.isNotEmpty(filters)) { OuterExpression.Expression exprProto = - filters.get(0).accept(new RexToProtoConverter(true, isColumnId)); + filters.get(0) + .accept(new RexToProtoConverter(true, isColumnId, this.rexBuilder)); Pointer ptrFilter = LIB.initSelectOperator(); checkFfiResult( LIB.setSelectPredicatePb( diff --git a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/runtime/proto/RexToProtoConverter.java b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/runtime/proto/RexToProtoConverter.java index 3e57e4e77d3f..220074c03e5d 100644 --- a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/runtime/proto/RexToProtoConverter.java +++ b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/runtime/proto/RexToProtoConverter.java @@ -27,6 +27,7 @@ import org.apache.calcite.rex.*; import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.util.Sarg; import java.util.List; @@ -35,10 +36,12 @@ */ public class RexToProtoConverter extends RexVisitorImpl { private final boolean isColumnId; + private final RexBuilder rexBuilder; - public RexToProtoConverter(boolean deep, boolean isColumnId) { + public RexToProtoConverter(boolean deep, boolean isColumnId, RexBuilder rexBuilder) { super(deep); this.isColumnId = isColumnId; + this.rexBuilder = rexBuilder; } @Override @@ -166,6 +169,26 @@ private OuterExpression.Expression visitIsNullOperator(RexNode operand) { } private OuterExpression.Expression visitBinaryOperator(RexCall call) { + if (call.getOperator().getKind() == SqlKind.SEARCH) { + // ir core can not support continuous ranges in a search operator, here expand it to + // compositions of 'and' or 'or', + // i.e. a.age SEARCH [[1, 10]] -> a.age >= 1 and a.age <= 10 + RexNode left = call.getOperands().get(0); + RexNode right = call.getOperands().get(1); + RexLiteral literal = null; + if (left instanceof RexLiteral) { + literal = (RexLiteral) left; + } else if (right instanceof RexLiteral) { + literal = (RexLiteral) right; + } + if (literal != null && literal.getValue() instanceof Sarg) { + Sarg sarg = (Sarg) literal.getValue(); + // search continuous ranges + if (!sarg.isPoints()) { + call = (RexCall) RexUtil.expandSearch(this.rexBuilder, null, call); + } + } + } SqlOperator operator = call.getOperator(); OuterExpression.Expression.Builder exprBuilder = OuterExpression.Expression.newBuilder(); // left-associative diff --git a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/runtime/proto/Utils.java b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/runtime/proto/Utils.java index 1245a82bb6d6..7345336353a0 100644 --- a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/runtime/proto/Utils.java +++ b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/runtime/proto/Utils.java @@ -17,7 +17,10 @@ package com.alibaba.graphscope.common.ir.runtime.proto; import com.alibaba.graphscope.common.ir.tools.config.GraphOpt; -import com.alibaba.graphscope.common.ir.type.*; +import com.alibaba.graphscope.common.ir.type.GraphLabelType; +import com.alibaba.graphscope.common.ir.type.GraphNameOrId; +import com.alibaba.graphscope.common.ir.type.GraphProperty; +import com.alibaba.graphscope.common.ir.type.GraphSchemaType; import com.alibaba.graphscope.gaia.proto.Common; import com.alibaba.graphscope.gaia.proto.DataType; import com.alibaba.graphscope.gaia.proto.GraphAlgebra; @@ -31,6 +34,7 @@ import org.apache.calcite.sql.SqlOperator; import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.util.NlsString; +import org.apache.calcite.util.Sarg; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -69,8 +73,54 @@ public static final Common.Value protoValue(RexLiteral literal) { return Common.Value.newBuilder() .setF64(((Number) literal.getValue()).doubleValue()) .build(); + case SARG: + Sarg sarg = literal.getValueAs(Sarg.class); + if (sarg.isPoints()) { + Common.Value.Builder valueBuilder = Common.Value.newBuilder(); + List values = + com.alibaba.graphscope.common.ir.tools.Utils.getValuesAsList(sarg); + if (values.isEmpty()) { + // return an empty string array to handle the case i.e. within [] + return valueBuilder + .setStrArray(Common.StringArray.newBuilder().build()) + .build(); + } + Comparable first = values.get(0); + if (first instanceof Integer) { + Common.I32Array.Builder i32Array = Common.I32Array.newBuilder(); + values.forEach(value -> i32Array.addItem((Integer) value)); + valueBuilder.setI32Array(i32Array); + } else if (first instanceof Number) { + Common.I64Array.Builder i64Array = Common.I64Array.newBuilder(); + values.forEach(value -> i64Array.addItem(((Number) value).longValue())); + valueBuilder.setI64Array(i64Array); + } else if (first instanceof Double || first instanceof Float) { + Common.DoubleArray.Builder doubleArray = Common.DoubleArray.newBuilder(); + values.forEach( + value -> + doubleArray.addItem( + (value instanceof Float) + ? (Float) value + : (Double) value)); + valueBuilder.setF64Array(doubleArray); + } else if (first instanceof String || first instanceof NlsString) { + Common.StringArray.Builder stringArray = Common.StringArray.newBuilder(); + values.forEach( + value -> + stringArray.addItem( + (value instanceof NlsString) + ? ((NlsString) value).getValue() + : (String) value)); + valueBuilder.setStrArray(stringArray); + } else { + throw new UnsupportedOperationException( + "can not convert value list=" + values + " to ir core array"); + } + return valueBuilder.build(); + } + throw new UnsupportedOperationException( + "can not convert continuous ranges to ir core structure, sarg=" + sarg); default: - // TODO: support int/double/string array throw new UnsupportedOperationException( "literal type " + literal.getTypeName() + " is unsupported yet"); } @@ -180,8 +230,11 @@ public static final OuterExpression.ExprOpr protoOperator(SqlOperator operator) return OuterExpression.ExprOpr.newBuilder() .setLogical(OuterExpression.Logical.ISNULL) .build(); + case SEARCH: + return OuterExpression.ExprOpr.newBuilder() + .setLogical(OuterExpression.Logical.WITHIN) + .build(); default: - // TODO: support IN and NOT_IN throw new UnsupportedOperationException( "operator type=" + operator.getKind() diff --git a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/tools/GraphBuilder.java b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/tools/GraphBuilder.java index d917e58d0b59..d52d233838f6 100644 --- a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/tools/GraphBuilder.java +++ b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/tools/GraphBuilder.java @@ -38,7 +38,10 @@ import com.alibaba.graphscope.common.ir.type.*; import com.alibaba.graphscope.gremlin.Utils; import com.google.common.base.Preconditions; -import com.google.common.collect.*; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Lists; +import com.google.common.collect.Sets; import org.apache.calcite.plan.*; import org.apache.calcite.rel.AbstractRelNode; @@ -55,7 +58,6 @@ import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.tools.RelBuilder; import org.apache.calcite.util.Litmus; -import org.apache.calcite.util.NlsString; import org.apache.calcite.util.Pair; import org.apache.calcite.util.Sarg; import org.apache.commons.lang3.ObjectUtils; @@ -819,25 +821,13 @@ private void classifyFilters( List conjunctions = RelOptUtil.conjunctions(condition); List filtersToRemove = Lists.newArrayList(); for (RexNode conjunction : conjunctions) { - if (conjunction instanceof RexCall) { - RexCall rexCall = (RexCall) conjunction; - if (rexCall.getOperator().getKind() == SqlKind.EQUALS - || rexCall.getOperator().getKind() == SqlKind.SEARCH) { - RexNode left = rexCall.getOperands().get(0); - RexNode right = rexCall.getOperands().get(1); - if (left.getType() instanceof GraphLabelType && right instanceof RexLiteral) { - filtersToRemove.add(conjunction); - labelValues.addAll( - getValuesAsList(((RexLiteral) right).getValueAs(Comparable.class))); - break; - } else if (left instanceof RexLiteral - && right.getType() instanceof GraphLabelType) { - filtersToRemove.add(conjunction); - labelValues.addAll( - getValuesAsList(((RexLiteral) left).getValueAs(Comparable.class))); - break; - } - } + RexLiteral labelLiteral = isLabelEqualFilter(conjunction); + if (labelLiteral != null) { + filtersToRemove.add(conjunction); + labelValues.addAll( + com.alibaba.graphscope.common.ir.tools.Utils.getValuesAsList( + labelLiteral.getValueAs(Comparable.class))); + break; } } if (tableScan instanceof GraphLogicalSource @@ -845,20 +835,9 @@ private void classifyFilters( // try to extract unique key filters from the original condition List disjunctions = RelOptUtil.disjunctions(condition); for (RexNode disjunction : disjunctions) { - if (disjunction instanceof RexCall) { - RexCall rexCall = (RexCall) disjunction; - if (rexCall.getOperator().getKind() == SqlKind.EQUALS - || rexCall.getOperator().getKind() == SqlKind.SEARCH) { - RexNode left = rexCall.getOperands().get(0); - RexNode right = rexCall.getOperands().get(1); - if (isUniqueKey(left) && isLiteralOrDynamicParams(right)) { - filtersToRemove.add(disjunction); - uniqueKeyFilters.add(disjunction); - } else if (isLiteralOrDynamicParams(left) && isUniqueKey(right)) { - filtersToRemove.add(disjunction); - uniqueKeyFilters.add(disjunction); - } - } + if (isUniqueKeyEqualFilter(disjunction)) { + filtersToRemove.add(disjunction); + uniqueKeyFilters.add(disjunction); } } } @@ -868,6 +847,80 @@ private void classifyFilters( filters.addAll(conjunctions); } + // check the condition if it is the pattern of label equal filter, i.e. ~label = 'person' or + // ~label within ['person', 'software'] + // if it is then return the literal containing label values, otherwise null + private @Nullable RexLiteral isLabelEqualFilter(RexNode condition) { + if (condition instanceof RexCall) { + RexCall rexCall = (RexCall) condition; + SqlOperator operator = rexCall.getOperator(); + switch (operator.getKind()) { + case EQUALS: + case SEARCH: + RexNode left = rexCall.getOperands().get(0); + RexNode right = rexCall.getOperands().get(1); + if (left.getType() instanceof GraphLabelType && right instanceof RexLiteral) { + Comparable value = ((RexLiteral) right).getValue(); + // if Sarg is a continuous range then the filter is not the 'equal', i.e. + // ~label SEARCH [[1, 10]] which means ~label >= 1 and ~label <= 10 + if (value instanceof Sarg && !((Sarg) value).isPoints()) { + return null; + } + return (RexLiteral) right; + } else if (right.getType() instanceof GraphLabelType + && left instanceof RexLiteral) { + Comparable value = ((RexLiteral) left).getValue(); + if (value instanceof Sarg && !((Sarg) value).isPoints()) { + return null; + } + return (RexLiteral) left; + } + default: + return null; + } + } else { + return null; + } + } + + // check the condition if it is the pattern of unique key equal filter, i.e. ~id = 1 or ~id + // within [1, 2] + private boolean isUniqueKeyEqualFilter(RexNode condition) { + if (condition instanceof RexCall) { + RexCall rexCall = (RexCall) condition; + SqlOperator operator = rexCall.getOperator(); + switch (operator.getKind()) { + case EQUALS: + case SEARCH: + RexNode left = rexCall.getOperands().get(0); + RexNode right = rexCall.getOperands().get(1); + if (isUniqueKey(left) && isLiteralOrDynamicParams(right)) { + if (right instanceof RexLiteral) { + Comparable value = ((RexLiteral) right).getValue(); + // if Sarg is a continuous range then the filter is not the 'equal', + // i.e. ~id SEARCH [[1, 10]] which means ~id >= 1 and ~id <= 10 + if (value instanceof Sarg && !((Sarg) value).isPoints()) { + return false; + } + } + return true; + } else if (isUniqueKey(right) && isLiteralOrDynamicParams(left)) { + if (left instanceof RexLiteral) { + Comparable value = ((RexLiteral) left).getValue(); + if (value instanceof Sarg && !((Sarg) value).isPoints()) { + return false; + } + } + return true; + } + default: + return false; + } + } else { + return false; + } + } + private boolean isUniqueKey(RexNode rexNode) { // todo: support primary keys if (rexNode instanceof RexGraphVariable) { @@ -882,24 +935,6 @@ private boolean isLiteralOrDynamicParams(RexNode node) { return node instanceof RexLiteral || node instanceof RexDynamicParam; } - private List getValuesAsList(Comparable value) { - ImmutableList.Builder labelBuilder = ImmutableList.builder(); - if (value instanceof NlsString) { - labelBuilder.add(((NlsString) value).getValue()); - } else if (value instanceof Sarg) { - Sarg sarg = (Sarg) value; - if (sarg.isPoints()) { - Set> rangeSets = sarg.rangeSet.asRanges(); - for (Range range : rangeSets) { - labelBuilder.addAll(getValuesAsList(range.lowerEndpoint())); - } - } - } else { - labelBuilder.add(value); - } - return labelBuilder.build(); - } - // return the top node if its type is Filter, otherwise null private Filter topFilter() { if (this.size() > 0 && this.peek() instanceof Filter) { diff --git a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/tools/Utils.java b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/tools/Utils.java index 701cc3ebb337..ea970934c9fd 100644 --- a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/tools/Utils.java +++ b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/tools/Utils.java @@ -16,7 +16,9 @@ package com.alibaba.graphscope.common.ir.tools; +import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; +import com.google.common.collect.Range; import com.google.common.collect.Sets; import org.apache.calcite.rel.RelNode; @@ -24,6 +26,8 @@ import org.apache.calcite.rel.type.RelDataTypeField; import org.apache.calcite.rel.type.RelRecordType; import org.apache.calcite.rel.type.StructKind; +import org.apache.calcite.util.NlsString; +import org.apache.calcite.util.Sarg; import java.util.List; import java.util.Set; @@ -53,4 +57,22 @@ public static RelDataType getOutputType(RelNode topNode) { .collect(Collectors.toList()); return new RelRecordType(StructKind.FULLY_QUALIFIED, dedup); } + + public static List getValuesAsList(Comparable value) { + ImmutableList.Builder valueBuilder = ImmutableList.builder(); + if (value instanceof NlsString) { + valueBuilder.add(((NlsString) value).getValue()); + } else if (value instanceof Sarg) { + Sarg sarg = (Sarg) value; + if (sarg.isPoints()) { + Set> rangeSets = sarg.rangeSet.asRanges(); + for (Range range : rangeSets) { + valueBuilder.addAll(getValuesAsList(range.lowerEndpoint())); + } + } + } else { + valueBuilder.add(value); + } + return valueBuilder.build(); + } } diff --git a/interactive_engine/compiler/src/test/java/com/alibaba/graphscope/common/ir/runtime/RexToProtoTest.java b/interactive_engine/compiler/src/test/java/com/alibaba/graphscope/common/ir/runtime/RexToProtoTest.java index 311f304f318a..191e112a274c 100644 --- a/interactive_engine/compiler/src/test/java/com/alibaba/graphscope/common/ir/runtime/RexToProtoTest.java +++ b/interactive_engine/compiler/src/test/java/com/alibaba/graphscope/common/ir/runtime/RexToProtoTest.java @@ -50,7 +50,7 @@ public void test_expression_with_brace() throws Exception { GraphStdOperatorTable.MINUS, builder.literal(1), builder.variable("a", "age"))); - RexToProtoConverter converter = new RexToProtoConverter(true, false); + RexToProtoConverter converter = new RexToProtoConverter(true, false, Utils.rexBuilder); Assert.assertEquals( FileUtils.readJsonFromResource("proto/expression_with_brace.json"), JsonFormat.printer().print(braceExpr.accept(converter)));