diff --git a/interactive_engine/compiler/src/main/java/org/apache/calcite/sql/type/GraphFamilyOperandTypeChecker.java b/interactive_engine/compiler/src/main/java/org/apache/calcite/sql/type/GraphFamilyOperandTypeChecker.java index c3eb912d08c8..182c1389c596 100644 --- a/interactive_engine/compiler/src/main/java/org/apache/calcite/sql/type/GraphFamilyOperandTypeChecker.java +++ b/interactive_engine/compiler/src/main/java/org/apache/calcite/sql/type/GraphFamilyOperandTypeChecker.java @@ -23,10 +23,12 @@ import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; -import org.apache.calcite.sql.*; +import org.apache.calcite.sql.SqlCallBinding; +import org.apache.calcite.sql.SqlNode; import org.apache.calcite.sql.validate.implicit.TypeCoercion; import org.apache.calcite.util.Litmus; +import java.util.Collection; import java.util.List; import java.util.function.Predicate; @@ -159,7 +161,7 @@ private boolean checkSingleOperandType( return true; } - if (!family.getTypeNames().contains(typeName)) { + if (!getAllowedTypeNames(family, iFormalOperand).contains(typeName)) { if (throwOnFailure) { throw callBinding.newValidationSignatureError(); } @@ -168,6 +170,11 @@ private boolean checkSingleOperandType( return true; } + protected Collection getAllowedTypeNames( + SqlTypeFamily family, int iFormalOperand) { + return family.getTypeNames(); + } + private boolean isNullLiteral(RexNode node) { if (node instanceof RexLiteral) { RexLiteral literal = (RexLiteral) node; diff --git a/interactive_engine/compiler/src/main/java/org/apache/calcite/sql/type/GraphOperandMetaDataImpl.java b/interactive_engine/compiler/src/main/java/org/apache/calcite/sql/type/GraphOperandMetaDataImpl.java index 6e7f20874a19..d162c50a7f48 100644 --- a/interactive_engine/compiler/src/main/java/org/apache/calcite/sql/type/GraphOperandMetaDataImpl.java +++ b/interactive_engine/compiler/src/main/java/org/apache/calcite/sql/type/GraphOperandMetaDataImpl.java @@ -16,15 +16,23 @@ package org.apache.calcite.sql.type; +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; + import org.apache.calcite.linq4j.function.Functions; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.SqlUtil; +import org.checkerframework.checker.nullness.qual.Nullable; +import java.util.Collection; import java.util.List; import java.util.Objects; import java.util.function.Function; import java.util.function.IntFunction; import java.util.function.Predicate; +import java.util.stream.Collectors; public class GraphOperandMetaDataImpl extends GraphFamilyOperandTypeChecker implements SqlOperandMetadata { @@ -33,7 +41,7 @@ public class GraphOperandMetaDataImpl extends GraphFamilyOperandTypeChecker GraphOperandMetaDataImpl( List families, - Function> paramTypesFactory, + Function<@Nullable RelDataTypeFactory, List> paramTypesFactory, IntFunction paramNameFn, Predicate optional) { super(families, optional); @@ -41,18 +49,41 @@ public class GraphOperandMetaDataImpl extends GraphFamilyOperandTypeChecker this.paramNameFn = paramNameFn; } + @Override + protected Collection getAllowedTypeNames( + SqlTypeFamily family, int iFormalOperand) { + List paramsAllowedTypes = paramTypes(null); + Preconditions.checkArgument( + paramsAllowedTypes.size() > iFormalOperand, + "cannot find allowed type for type index=" + + iFormalOperand + + " from the allowed types list=" + + paramsAllowedTypes); + return ImmutableList.of(paramsAllowedTypes.get(iFormalOperand).getSqlTypeName()); + } + @Override public boolean isFixedParameters() { return true; } @Override - public List paramTypes(RelDataTypeFactory typeFactory) { - return (List) this.paramTypesFactory.apply(typeFactory); + public List paramTypes(@Nullable RelDataTypeFactory typeFactory) { + return this.paramTypesFactory.apply(typeFactory); } @Override public List paramNames() { return Functions.generate(this.families.size(), this.paramNameFn); } + + @Override + public String getAllowedSignatures(SqlOperator op, String opName) { + return SqlUtil.getAliasedSignature( + op, + opName, + paramTypes(null).stream() + .map(k -> k.getSqlTypeName()) + .collect(Collectors.toList())); + } } diff --git a/interactive_engine/compiler/src/test/java/com/alibaba/graphscope/cypher/antlr4/CallProcedureTest.java b/interactive_engine/compiler/src/test/java/com/alibaba/graphscope/cypher/antlr4/CallProcedureTest.java index 9070f20b3975..8efe66fb205b 100644 --- a/interactive_engine/compiler/src/test/java/com/alibaba/graphscope/cypher/antlr4/CallProcedureTest.java +++ b/interactive_engine/compiler/src/test/java/com/alibaba/graphscope/cypher/antlr4/CallProcedureTest.java @@ -18,15 +18,31 @@ import com.alibaba.graphscope.common.ir.tools.LogicalPlan; +import org.apache.calcite.runtime.CalciteException; import org.junit.Assert; import org.junit.Test; public class CallProcedureTest { @Test - public void match_1_test() { - LogicalPlan logicalPlan = Utils.evalLogicalPlan("Call ldbc_ic2(10l, 20120112l)"); + public void procedure_1_test() { + LogicalPlan logicalPlan = Utils.evalLogicalPlan("Call ldbc_ic2(10, 20120112l)"); Assert.assertEquals("ldbc_ic2(10:BIGINT, 20120112:BIGINT)", logicalPlan.explain().trim()); Assert.assertEquals( "RecordType(CHAR(1) name)", logicalPlan.getProcedureCall().getType().toString()); } + + // test procedure with invalid parameter types + @Test + public void procedure_2_test() { + try { + Utils.evalLogicalPlan("Call ldbc_ic2(10, 20120112l)"); + } catch (CalciteException e) { + Assert.assertEquals( + "Cannot apply ldbc_ic2 to arguments of type 'ldbc_ic2(, )'." + + " Supported form(s): 'ldbc_ic2(, )'", + e.getMessage()); + return; + } + Assert.fail(); + } }