Skip to content

Commit

Permalink
[GIE Compiler] classify filters into label equals or unique key equal…
Browse files Browse the repository at this point in the history
…s or others
  • Loading branch information
shirly121 committed Oct 12, 2023
1 parent caabbbd commit 0d24045
Show file tree
Hide file tree
Showing 16 changed files with 1,373 additions and 144 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,15 @@
import org.apache.calcite.plan.GraphOptCluster;
import org.apache.calcite.rel.RelWriter;
import org.apache.calcite.rel.hint.RelHint;
import org.apache.calcite.rex.RexNode;
import org.checkerframework.checker.nullness.qual.Nullable;

import java.util.List;
import java.util.Objects;

public class GraphLogicalSource extends AbstractBindableTableScan {
private final GraphOpt.Source opt;
private @Nullable RexNode uniqueKeyFilters;

protected GraphLogicalSource(
GraphOptCluster cluster,
Expand All @@ -54,6 +57,16 @@ public GraphOpt.Source getOpt() {

@Override
public RelWriter explainTerms(RelWriter pw) {
return super.explainTerms(pw).item("opt", getOpt());
return super.explainTerms(pw)
.item("opt", getOpt())
.itemIf("uniqueKeyFilters", uniqueKeyFilters, uniqueKeyFilters != null);
}

public void setUniqueKeyFilters(RexNode uniqueKeyFilters) {
this.uniqueKeyFilters = Objects.requireNonNull(uniqueKeyFilters);
}

public @Nullable RexNode getUniqueKeyFilters() {
return uniqueKeyFilters;
}
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ public FfiPhysicalBuilder(
Configs graphConfig, IrMeta irMeta, LogicalPlan logicalPlan, PlanPointer planPointer) {
super(
logicalPlan,
new GraphRelShuttleWrapper(new RelToFfiConverter(irMeta.getSchema().isColumnId())));
new GraphRelShuttleWrapper(
new RelToFfiConverter(irMeta.getSchema().isColumnId(), graphConfig)));
this.graphConfig = graphConfig;
this.irMeta = irMeta;
this.planPointer = Objects.requireNonNull(planPointer);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package com.alibaba.graphscope.common.ir.runtime.ffi;

import com.alibaba.graphscope.common.config.Configs;
import com.alibaba.graphscope.common.intermediate.ArgUtils;
import com.alibaba.graphscope.common.ir.rel.GraphLogicalAggregate;
import com.alibaba.graphscope.common.ir.rel.GraphLogicalProject;
Expand All @@ -31,8 +32,10 @@
import com.alibaba.graphscope.common.ir.runtime.proto.RexToProtoConverter;
import com.alibaba.graphscope.common.ir.runtime.type.PhysicalNode;
import com.alibaba.graphscope.common.ir.tools.AliasInference;
import com.alibaba.graphscope.common.ir.tools.GraphPlanner;
import com.alibaba.graphscope.common.ir.tools.config.GraphOpt;
import com.alibaba.graphscope.common.ir.type.GraphLabelType;
import com.alibaba.graphscope.common.ir.type.GraphNameOrId;
import com.alibaba.graphscope.common.ir.type.GraphProperty;
import com.alibaba.graphscope.common.ir.type.GraphSchemaType;
import com.alibaba.graphscope.common.jna.IrCoreLibrary;
Expand All @@ -49,17 +52,13 @@
import org.apache.calcite.rel.logical.LogicalFilter;
import org.apache.calcite.rel.logical.LogicalJoin;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexVariable;
import org.apache.calcite.rex.*;
import org.apache.calcite.sql.SqlKind;
import org.apache.commons.lang3.ObjectUtils;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
Expand All @@ -72,9 +71,11 @@ public class RelToFfiConverter implements GraphRelShuttle {
private static final Logger logger = LoggerFactory.getLogger(RelToFfiConverter.class);
private static final IrCoreLibrary LIB = IrCoreLibrary.INSTANCE;
private final boolean isColumnId;
private final RexBuilder rexBuilder;

public RelToFfiConverter(boolean isColumnId) {
public RelToFfiConverter(boolean isColumnId, Configs configs) {
this.isColumnId = isColumnId;
this.rexBuilder = GraphPlanner.rexBuilderFactory.apply(configs);
}

@Override
Expand Down Expand Up @@ -205,7 +206,9 @@ public RelNode visit(GraphLogicalMultiMatch match) {
@Override
public RelNode visit(LogicalFilter logicalFilter) {
OuterExpression.Expression exprProto =
logicalFilter.getCondition().accept(new RexToProtoConverter(true, isColumnId));
logicalFilter
.getCondition()
.accept(new RexToProtoConverter(true, isColumnId, this.rexBuilder));
Pointer ptrFilter = LIB.initSelectOperator();
checkFfiResult(
LIB.setSelectPredicatePb(
Expand All @@ -219,7 +222,9 @@ public PhysicalNode visit(GraphLogicalProject project) {
List<RelDataTypeField> fields = project.getRowType().getFieldList();
for (int i = 0; i < project.getProjects().size(); ++i) {
OuterExpression.Expression expression =
project.getProjects().get(i).accept(new RexToProtoConverter(true, isColumnId));
project.getProjects()
.get(i)
.accept(new RexToProtoConverter(true, isColumnId, this.rexBuilder));
int aliasId = fields.get(i).getIndex();
FfiAlias.ByValue ffiAlias =
(aliasId == AliasInference.DEFAULT_ID)
Expand Down Expand Up @@ -260,7 +265,7 @@ public PhysicalNode visit(GraphLogicalAggregate aggregate) {
RexGraphVariable.class,
var.getClass());
OuterExpression.Expression expr =
var.accept(new RexToProtoConverter(true, isColumnId));
var.accept(new RexToProtoConverter(true, isColumnId, this.rexBuilder));
int aliasId;
if (i >= fields.size()
|| (aliasId = fields.get(i).getIndex()) == AliasInference.DEFAULT_ID) {
Expand All @@ -283,7 +288,7 @@ public PhysicalNode visit(GraphLogicalAggregate aggregate) {
field.getName(),
field.getType());
OuterExpression.Variable exprVar =
rexVar.accept(new RexToProtoConverter(true, isColumnId))
rexVar.accept(new RexToProtoConverter(true, isColumnId, this.rexBuilder))
.getOperators(0)
.getVar();
checkFfiResult(
Expand All @@ -305,7 +310,7 @@ public PhysicalNode visit(GraphLogicalAggregate aggregate) {
OuterExpression.Variable var =
groupKeys
.get(i)
.accept(new RexToProtoConverter(true, isColumnId))
.accept(new RexToProtoConverter(true, isColumnId, this.rexBuilder))
.getOperators(0)
.getVar();
int aliasId = fields.get(i).getIndex();
Expand Down Expand Up @@ -339,7 +344,7 @@ public PhysicalNode visit(GraphLogicalAggregate aggregate) {
operands.get(0).getClass());
OuterExpression.Variable var =
operands.get(0)
.accept(new RexToProtoConverter(true, isColumnId))
.accept(new RexToProtoConverter(true, isColumnId, this.rexBuilder))
.getOperators(0)
.getVar();
checkFfiResult(
Expand Down Expand Up @@ -369,7 +374,7 @@ public PhysicalNode visit(GraphLogicalSort sort) {
for (int i = 0; i < collations.size(); ++i) {
RexGraphVariable expr = ((GraphFieldCollation) collations.get(i)).getVariable();
OuterExpression.Variable var =
expr.accept(new RexToProtoConverter(true, isColumnId))
expr.accept(new RexToProtoConverter(true, isColumnId, this.rexBuilder))
.getOperators(0)
.getVar();
checkFfiResult(
Expand Down Expand Up @@ -405,13 +410,13 @@ public PhysicalNode visit(LogicalJoin join) {
OuterExpression.Variable leftVar =
leftRightVars
.get(0)
.accept(new RexToProtoConverter(true, isColumnId))
.accept(new RexToProtoConverter(true, isColumnId, this.rexBuilder))
.getOperators(0)
.getVar();
OuterExpression.Variable rightVar =
leftRightVars
.get(1)
.accept(new RexToProtoConverter(true, isColumnId))
.accept(new RexToProtoConverter(true, isColumnId, this.rexBuilder))
.getOperators(0)
.getVar();
checkFfiResult(
Expand Down Expand Up @@ -451,7 +456,10 @@ private Pointer ffiQueryParams(AbstractBindableTableScan tableScan) {
});
if (ObjectUtils.isNotEmpty(tableScan.getFilters())) {
OuterExpression.Expression expression =
tableScan.getFilters().get(0).accept(new RexToProtoConverter(true, isColumnId));
tableScan
.getFilters()
.get(0)
.accept(new RexToProtoConverter(true, isColumnId, this.rexBuilder));
checkFfiResult(
LIB.setParamsPredicatePb(
params, new FfiPbPointer.ByValue(expression.toByteArray())));
Expand All @@ -460,46 +468,85 @@ private Pointer ffiQueryParams(AbstractBindableTableScan tableScan) {
}

private @Nullable Pointer ffiIndexPredicates(GraphLogicalSource source) {
ImmutableList<RexNode> filters = source.getFilters();
if (ObjectUtils.isEmpty(filters)) return null;
// decomposed by OR
List<RexNode> disJunctions = RelOptUtil.disjunctions(filters.get(0));
List<RexNode> literals = new ArrayList<>();
for (RexNode rexNode : disJunctions) {
if (!isIdEqualsLiteral(rexNode, literals)) {
return null;
}
}
RexNode uniqueKeyFilters = source.getUniqueKeyFilters();
if (uniqueKeyFilters == null) return null;
// 'within' operator in index predicate is unsupported in ir core, here just expand it to
// 'or'
// i.e. '~id within [1, 2]' -> '~id == 1 or ~id == 2'
RexNode expandSearch = RexUtil.expandSearch(this.rexBuilder, null, uniqueKeyFilters);
List<RexNode> disjunctions = RelOptUtil.disjunctions(expandSearch);
Pointer ptrIndex = LIB.initIndexPredicate();
for (RexNode literal : literals) {
FfiProperty.ByValue property = new FfiProperty.ByValue();
property.opt = FfiPropertyOpt.Id;
checkFfiResult(
LIB.orEquivPredicate(ptrIndex, property, Utils.ffiConst((RexLiteral) literal)));
for (RexNode disjunction : disjunctions) {
if (disjunction instanceof RexCall) {
RexCall rexCall = (RexCall) disjunction;
switch (rexCall.getOperator().getKind()) {
case EQUALS:
RexNode left = rexCall.getOperands().get(0);
RexNode right = rexCall.getOperands().get(1);
if (left instanceof RexGraphVariable
&& (right instanceof RexLiteral
|| right instanceof RexDynamicParam)) {
LIB.orEquivPredicate(
ptrIndex,
getFfiProperty(((RexGraphVariable) left).getProperty()),
getFfiConst(right));
break;
} else if (right instanceof RexGraphVariable
&& (left instanceof RexLiteral
|| left instanceof RexDynamicParam)) {
LIB.orEquivPredicate(
ptrIndex,
getFfiProperty(((RexGraphVariable) right).getProperty()),
getFfiConst(left));
break;
}
default:
throw new IllegalArgumentException(
"can not convert unique key filter pattern="
+ rexCall
+ " to ir core index predicate");
}
} else {
throw new IllegalArgumentException(
"invalid unique key filter pattern=" + disjunction);
}
}
// remove index predicates from filter conditions
source.setFilters(ImmutableList.of());
return ptrIndex;
}

// i.e. a.~id == 10
private boolean isIdEqualsLiteral(RexNode rexNode, List<RexNode> literal) {
if (rexNode.getKind() != SqlKind.EQUALS) return false;
List<RexNode> operands = ((RexCall) rexNode).getOperands();
return isGlobalId(operands.get(0)) && isLiteral(operands.get(1), literal)
|| isGlobalId(operands.get(1)) && isLiteral(operands.get(0), literal);
private FfiProperty.ByValue getFfiProperty(GraphProperty property) {
Preconditions.checkArgument(property != null, "unique key should not be null");
FfiProperty.ByValue ffiProperty = new FfiProperty.ByValue();
switch (property.getOpt()) {
case ID:
ffiProperty.opt = FfiPropertyOpt.Id;
break;
case KEY:
ffiProperty.opt = FfiPropertyOpt.Key;
ffiProperty.key = getFfiNameOrId(property.getKey());
break;
default:
throw new IllegalArgumentException(
"can not convert property=" + property + " to ffi property");
}
return ffiProperty;
}

// i.e. a.~id
private boolean isGlobalId(RexNode rexNode) {
return rexNode instanceof RexGraphVariable
&& ((RexGraphVariable) rexNode).getProperty().getOpt() == GraphProperty.Opt.ID;
private FfiNameOrId.ByValue getFfiNameOrId(GraphNameOrId nameOrId) {
switch (nameOrId.getOpt()) {
case NAME:
return ArgUtils.asNameOrId(nameOrId.getName());
case ID:
default:
return ArgUtils.asNameOrId(nameOrId.getId());
}
}

private boolean isLiteral(RexNode rexNode, List<RexNode> literal) {
boolean isLiteral = rexNode.getKind() == SqlKind.LITERAL;
if (isLiteral) literal.add(rexNode);
return isLiteral;
private FfiConst.ByValue getFfiConst(RexNode rexNode) {
if (rexNode instanceof RexLiteral) {
return Utils.ffiConst((RexLiteral) rexNode);
}
throw new IllegalArgumentException("cannot convert rexNode=" + rexNode + " to ffi const");
}

private List<Integer> range(RexNode offset, RexNode fetch) {
Expand Down Expand Up @@ -592,7 +639,8 @@ private void addFilterToFfiBinder(Pointer ptrSentence, AbstractBindableTableScan
List<RexNode> filters = tableScan.getFilters();
if (ObjectUtils.isNotEmpty(filters)) {
OuterExpression.Expression exprProto =
filters.get(0).accept(new RexToProtoConverter(true, isColumnId));
filters.get(0)
.accept(new RexToProtoConverter(true, isColumnId, this.rexBuilder));
Pointer ptrFilter = LIB.initSelectOperator();
checkFfiResult(
LIB.setSelectPredicatePb(
Expand Down
Loading

0 comments on commit 0d24045

Please sign in to comment.