Skip to content

Commit

Permalink
[GIE Compiler] convert continuous ranges search to compositions of 'a…
Browse files Browse the repository at this point in the history
…nd' or 'or'
  • Loading branch information
shirly121 committed Oct 12, 2023
1 parent d9b37f5 commit ae2ba09
Show file tree
Hide file tree
Showing 6 changed files with 210 additions and 69 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,9 @@ public RelNode visit(GraphLogicalMultiMatch match) {
@Override
public RelNode visit(LogicalFilter logicalFilter) {
OuterExpression.Expression exprProto =
logicalFilter.getCondition().accept(new RexToProtoConverter(true, isColumnId));
logicalFilter
.getCondition()
.accept(new RexToProtoConverter(true, isColumnId, this.rexBuilder));
Pointer ptrFilter = LIB.initSelectOperator();
checkFfiResult(
LIB.setSelectPredicatePb(
Expand All @@ -220,7 +222,9 @@ public PhysicalNode visit(GraphLogicalProject project) {
List<RelDataTypeField> fields = project.getRowType().getFieldList();
for (int i = 0; i < project.getProjects().size(); ++i) {
OuterExpression.Expression expression =
project.getProjects().get(i).accept(new RexToProtoConverter(true, isColumnId));
project.getProjects()
.get(i)
.accept(new RexToProtoConverter(true, isColumnId, this.rexBuilder));
int aliasId = fields.get(i).getIndex();
FfiAlias.ByValue ffiAlias =
(aliasId == AliasInference.DEFAULT_ID)
Expand Down Expand Up @@ -261,7 +265,7 @@ public PhysicalNode visit(GraphLogicalAggregate aggregate) {
RexGraphVariable.class,
var.getClass());
OuterExpression.Expression expr =
var.accept(new RexToProtoConverter(true, isColumnId));
var.accept(new RexToProtoConverter(true, isColumnId, this.rexBuilder));
int aliasId;
if (i >= fields.size()
|| (aliasId = fields.get(i).getIndex()) == AliasInference.DEFAULT_ID) {
Expand All @@ -284,7 +288,7 @@ public PhysicalNode visit(GraphLogicalAggregate aggregate) {
field.getName(),
field.getType());
OuterExpression.Variable exprVar =
rexVar.accept(new RexToProtoConverter(true, isColumnId))
rexVar.accept(new RexToProtoConverter(true, isColumnId, this.rexBuilder))
.getOperators(0)
.getVar();
checkFfiResult(
Expand All @@ -306,7 +310,7 @@ public PhysicalNode visit(GraphLogicalAggregate aggregate) {
OuterExpression.Variable var =
groupKeys
.get(i)
.accept(new RexToProtoConverter(true, isColumnId))
.accept(new RexToProtoConverter(true, isColumnId, this.rexBuilder))
.getOperators(0)
.getVar();
int aliasId = fields.get(i).getIndex();
Expand Down Expand Up @@ -340,7 +344,7 @@ public PhysicalNode visit(GraphLogicalAggregate aggregate) {
operands.get(0).getClass());
OuterExpression.Variable var =
operands.get(0)
.accept(new RexToProtoConverter(true, isColumnId))
.accept(new RexToProtoConverter(true, isColumnId, this.rexBuilder))
.getOperators(0)
.getVar();
checkFfiResult(
Expand Down Expand Up @@ -370,7 +374,7 @@ public PhysicalNode visit(GraphLogicalSort sort) {
for (int i = 0; i < collations.size(); ++i) {
RexGraphVariable expr = ((GraphFieldCollation) collations.get(i)).getVariable();
OuterExpression.Variable var =
expr.accept(new RexToProtoConverter(true, isColumnId))
expr.accept(new RexToProtoConverter(true, isColumnId, this.rexBuilder))
.getOperators(0)
.getVar();
checkFfiResult(
Expand Down Expand Up @@ -406,13 +410,13 @@ public PhysicalNode visit(LogicalJoin join) {
OuterExpression.Variable leftVar =
leftRightVars
.get(0)
.accept(new RexToProtoConverter(true, isColumnId))
.accept(new RexToProtoConverter(true, isColumnId, this.rexBuilder))
.getOperators(0)
.getVar();
OuterExpression.Variable rightVar =
leftRightVars
.get(1)
.accept(new RexToProtoConverter(true, isColumnId))
.accept(new RexToProtoConverter(true, isColumnId, this.rexBuilder))
.getOperators(0)
.getVar();
checkFfiResult(
Expand Down Expand Up @@ -452,7 +456,10 @@ private Pointer ffiQueryParams(AbstractBindableTableScan tableScan) {
});
if (ObjectUtils.isNotEmpty(tableScan.getFilters())) {
OuterExpression.Expression expression =
tableScan.getFilters().get(0).accept(new RexToProtoConverter(true, isColumnId));
tableScan
.getFilters()
.get(0)
.accept(new RexToProtoConverter(true, isColumnId, this.rexBuilder));
checkFfiResult(
LIB.setParamsPredicatePb(
params, new FfiPbPointer.ByValue(expression.toByteArray())));
Expand Down Expand Up @@ -632,7 +639,8 @@ private void addFilterToFfiBinder(Pointer ptrSentence, AbstractBindableTableScan
List<RexNode> filters = tableScan.getFilters();
if (ObjectUtils.isNotEmpty(filters)) {
OuterExpression.Expression exprProto =
filters.get(0).accept(new RexToProtoConverter(true, isColumnId));
filters.get(0)
.accept(new RexToProtoConverter(true, isColumnId, this.rexBuilder));
Pointer ptrFilter = LIB.initSelectOperator();
checkFfiResult(
LIB.setSelectPredicatePb(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.apache.calcite.rex.*;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.util.Sarg;

import java.util.List;

Expand All @@ -35,10 +36,12 @@
*/
public class RexToProtoConverter extends RexVisitorImpl<OuterExpression.Expression> {
private final boolean isColumnId;
private final RexBuilder rexBuilder;

public RexToProtoConverter(boolean deep, boolean isColumnId) {
public RexToProtoConverter(boolean deep, boolean isColumnId, RexBuilder rexBuilder) {
super(deep);
this.isColumnId = isColumnId;
this.rexBuilder = rexBuilder;
}

@Override
Expand Down Expand Up @@ -166,6 +169,26 @@ private OuterExpression.Expression visitIsNullOperator(RexNode operand) {
}

private OuterExpression.Expression visitBinaryOperator(RexCall call) {
if (call.getOperator().getKind() == SqlKind.SEARCH) {
// ir core can not support continuous ranges in a search operator, here expand it to
// compositions of 'and' or 'or',
// i.e. a.age SEARCH [[1, 10]] -> a.age >= 1 and a.age <= 10
RexNode left = call.getOperands().get(0);
RexNode right = call.getOperands().get(1);
RexLiteral literal = null;
if (left instanceof RexLiteral) {
literal = (RexLiteral) left;
} else if (right instanceof RexLiteral) {
literal = (RexLiteral) right;
}
if (literal != null && literal.getValue() instanceof Sarg) {
Sarg sarg = (Sarg) literal.getValue();
// search continuous ranges
if (!sarg.isPoints()) {
call = (RexCall) RexUtil.expandSearch(this.rexBuilder, null, call);
}
}
}
SqlOperator operator = call.getOperator();
OuterExpression.Expression.Builder exprBuilder = OuterExpression.Expression.newBuilder();
// left-associative
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@
package com.alibaba.graphscope.common.ir.runtime.proto;

import com.alibaba.graphscope.common.ir.tools.config.GraphOpt;
import com.alibaba.graphscope.common.ir.type.*;
import com.alibaba.graphscope.common.ir.type.GraphLabelType;
import com.alibaba.graphscope.common.ir.type.GraphNameOrId;
import com.alibaba.graphscope.common.ir.type.GraphProperty;
import com.alibaba.graphscope.common.ir.type.GraphSchemaType;
import com.alibaba.graphscope.gaia.proto.Common;
import com.alibaba.graphscope.gaia.proto.DataType;
import com.alibaba.graphscope.gaia.proto.GraphAlgebra;
Expand All @@ -31,6 +34,7 @@
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.util.NlsString;
import org.apache.calcite.util.Sarg;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down Expand Up @@ -69,8 +73,54 @@ public static final Common.Value protoValue(RexLiteral literal) {
return Common.Value.newBuilder()
.setF64(((Number) literal.getValue()).doubleValue())
.build();
case SARG:
Sarg sarg = literal.getValueAs(Sarg.class);
if (sarg.isPoints()) {
Common.Value.Builder valueBuilder = Common.Value.newBuilder();
List<Comparable> values =
com.alibaba.graphscope.common.ir.tools.Utils.getValuesAsList(sarg);
if (values.isEmpty()) {
// return an empty string array to handle the case i.e. within []
return valueBuilder
.setStrArray(Common.StringArray.newBuilder().build())
.build();
}
Comparable first = values.get(0);
if (first instanceof Integer) {
Common.I32Array.Builder i32Array = Common.I32Array.newBuilder();
values.forEach(value -> i32Array.addItem((Integer) value));
valueBuilder.setI32Array(i32Array);
} else if (first instanceof Number) {
Common.I64Array.Builder i64Array = Common.I64Array.newBuilder();
values.forEach(value -> i64Array.addItem(((Number) value).longValue()));
valueBuilder.setI64Array(i64Array);
} else if (first instanceof Double || first instanceof Float) {
Common.DoubleArray.Builder doubleArray = Common.DoubleArray.newBuilder();
values.forEach(
value ->
doubleArray.addItem(
(value instanceof Float)
? (Float) value
: (Double) value));
valueBuilder.setF64Array(doubleArray);
} else if (first instanceof String || first instanceof NlsString) {
Common.StringArray.Builder stringArray = Common.StringArray.newBuilder();
values.forEach(
value ->
stringArray.addItem(
(value instanceof NlsString)
? ((NlsString) value).getValue()
: (String) value));
valueBuilder.setStrArray(stringArray);
} else {
throw new UnsupportedOperationException(
"can not convert value list=" + values + " to ir core array");
}
return valueBuilder.build();
}
throw new UnsupportedOperationException(
"can not convert continuous ranges to ir core structure, sarg=" + sarg);
default:
// TODO: support int/double/string array
throw new UnsupportedOperationException(
"literal type " + literal.getTypeName() + " is unsupported yet");
}
Expand Down Expand Up @@ -180,8 +230,11 @@ public static final OuterExpression.ExprOpr protoOperator(SqlOperator operator)
return OuterExpression.ExprOpr.newBuilder()
.setLogical(OuterExpression.Logical.ISNULL)
.build();
case SEARCH:
return OuterExpression.ExprOpr.newBuilder()
.setLogical(OuterExpression.Logical.WITHIN)
.build();
default:
// TODO: support IN and NOT_IN
throw new UnsupportedOperationException(
"operator type="
+ operator.getKind()
Expand Down
Loading

0 comments on commit ae2ba09

Please sign in to comment.