diff --git a/docs/interactive_engine/neo4j/supported_cypher.md b/docs/interactive_engine/neo4j/supported_cypher.md index fc502a149d54..eddd9fccc266 100644 --- a/docs/interactive_engine/neo4j/supported_cypher.md +++ b/docs/interactive_engine/neo4j/supported_cypher.md @@ -101,7 +101,9 @@ Note that some Aggregator operators, such as `max()`, we listed here are impleme | Labels | Get label name of a vertex type | labels() | labels() | | | | Type | Get label name of an edge type | type() | type() | | | | Extract | Get interval value from a temporal type | \.\ | \.\ | | | - +| Starts With | Perform case-sensitive matching on the beginning of a string | STARTS WITH | STARTS WITH | | | +| Ends With | Perform case-sensitive matching on the ending of a string | ENDS WITH | ENDS WITH | | | +| Contains | Perform case-sensitive matching regardless of location within a string | CONTAINS | CONTAINS | | | ## Clause diff --git a/interactive_engine/compiler/ir_exprimental_ci.sh b/interactive_engine/compiler/ir_exprimental_ci.sh index 0c4ab17d3ee7..83f509da481e 100755 --- a/interactive_engine/compiler/ir_exprimental_ci.sh +++ b/interactive_engine/compiler/ir_exprimental_ci.sh @@ -24,6 +24,7 @@ sleep 5s cd ${base_dir} && make run graph.schema:=../executor/ir/core/resource/movie_schema.json & sleep 10s # run cypher movie tests +export ENGINE_TYPE=pegasus cd ${base_dir} && make cypher_test exit_code=$? # clean service diff --git a/interactive_engine/compiler/src/main/antlr4/CypherGS.g4 b/interactive_engine/compiler/src/main/antlr4/CypherGS.g4 index e70bb541ff72..da2bfb5f3965 100644 --- a/interactive_engine/compiler/src/main/antlr4/CypherGS.g4 +++ b/interactive_engine/compiler/src/main/antlr4/CypherGS.g4 @@ -226,9 +226,17 @@ oC_PartialComparisonExpression | ( '<=' SP? oC_StringListNullPredicateExpression ) | ( '>=' SP? oC_StringListNullPredicateExpression ) ; - oC_StringListNullPredicateExpression - : oC_AddOrSubtractExpression ( oC_NullPredicateExpression )? 
; + : oC_AddOrSubtractExpression ( oC_StringPredicateExpression | oC_NullPredicateExpression )* ; + +oC_StringPredicateExpression + : ( ( SP STARTS SP WITH ) | ( SP ENDS SP WITH ) | ( SP CONTAINS ) ) SP? oC_AddOrSubtractExpression ; + +STARTS : ( 'S' | 's' ) ( 'T' | 't' ) ( 'A' | 'a' ) ( 'R' | 'r' ) ( 'T' | 't' ) ( 'S' | 's' ) ; + +ENDS : ( 'E' | 'e' ) ( 'N' | 'n' ) ( 'D' | 'd' ) ( 'S' | 's' ) ; + +CONTAINS : ( 'C' | 'c' ) ( 'O' | 'o' ) ( 'N' | 'n' ) ( 'T' | 't' ) ( 'A' | 'a' ) ( 'I' | 'i' ) ( 'N' | 'n' ) ( 'S' | 's' ) ; oC_NullPredicateExpression : ( SP IS SP NULL ) diff --git a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/runtime/proto/Utils.java b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/runtime/proto/Utils.java index 25c7df05fdd6..736b49ce18bd 100644 --- a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/runtime/proto/Utils.java +++ b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/runtime/proto/Utils.java @@ -224,6 +224,10 @@ public static final OuterExpression.ExprOpr protoOperator(SqlOperator operator) return OuterExpression.ExprOpr.newBuilder() .setLogical(OuterExpression.Logical.WITHIN) .build(); + case POSIX_REGEX_CASE_SENSITIVE: + return OuterExpression.ExprOpr.newBuilder() + .setLogical(OuterExpression.Logical.REGEX) + .build(); default: throw new UnsupportedOperationException( "operator type=" diff --git a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/tools/GraphBuilder.java b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/tools/GraphBuilder.java index 010da648d2bf..276dad1425f5 100644 --- a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/tools/GraphBuilder.java +++ b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/tools/GraphBuilder.java @@ -658,7 +658,8 @@ private boolean isCurrentSupported(SqlOperator operator) { || sqlKind 
== SqlKind.IS_NULL || sqlKind == SqlKind.IS_NOT_NULL || sqlKind == SqlKind.EXTRACT - || sqlKind == SqlKind.SEARCH; + || sqlKind == SqlKind.SEARCH + || sqlKind == SqlKind.POSIX_REGEX_CASE_SENSITIVE; } @Override diff --git a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/tools/GraphStdOperatorTable.java b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/tools/GraphStdOperatorTable.java index 1959fa4571fb..f48245b067c2 100644 --- a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/tools/GraphStdOperatorTable.java +++ b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/tools/GraphStdOperatorTable.java @@ -20,6 +20,7 @@ import com.alibaba.graphscope.common.ir.rex.operator.CaseOperator; import org.apache.calcite.sql.*; +import org.apache.calcite.sql.fun.ExtSqlPosixRegexOperator; import org.apache.calcite.sql.fun.SqlMonotonicBinaryOperator; import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.sql.type.*; @@ -226,4 +227,8 @@ public static final SqlFunction USER_DEFINED_PROCEDURE(StoredProcedureMeta meta) null, GraphOperandTypes.INTERVALINTERVAL_INTERVALDATETIME, SqlFunctionCategory.SYSTEM); + + public static final SqlOperator POSIX_REGEX_CASE_SENSITIVE = + new ExtSqlPosixRegexOperator( + "POSIX REGEX CASE SENSITIVE", SqlKind.POSIX_REGEX_CASE_SENSITIVE, true, false); } diff --git a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/cypher/antlr4/visitor/ExpressionVisitor.java b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/cypher/antlr4/visitor/ExpressionVisitor.java index f14b13a0742f..b21e26a9fbbc 100644 --- a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/cypher/antlr4/visitor/ExpressionVisitor.java +++ b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/cypher/antlr4/visitor/ExpressionVisitor.java @@ -33,6 +33,7 @@ import com.google.common.collect.ImmutableMap; import 
com.google.common.collect.Lists; +import org.antlr.v4.runtime.tree.ParseTree; import org.antlr.v4.runtime.tree.TerminalNode; import org.apache.calcite.avatica.util.TimeUnit; import org.apache.calcite.rel.RelNode; @@ -44,7 +45,9 @@ import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.SqlOperator; import org.apache.calcite.sql.parser.SqlParserPos; +import org.apache.calcite.sql.type.SqlTypeFamily; import org.apache.calcite.tools.RelBuilder; +import org.apache.calcite.util.NlsString; import org.apache.commons.lang3.ObjectUtils; import org.checkerframework.checker.nullness.qual.Nullable; @@ -128,18 +131,72 @@ public ExprVisitorResult visitOC_StringListNullPredicateExpression( CypherGSParser.OC_StringListNullPredicateExpressionContext ctx) { ExprVisitorResult operand = visitOC_AddOrSubtractExpression(ctx.oC_AddOrSubtractExpression()); - List operators = Lists.newArrayList(); - CypherGSParser.OC_NullPredicateExpressionContext nullCtx = ctx.oC_NullPredicateExpression(); - if (nullCtx != null) { - if (nullCtx.IS() != null && nullCtx.NOT() != null && nullCtx.NULL() != null) { - operators.add(GraphStdOperatorTable.IS_NOT_NULL); - } else if (nullCtx.IS() != null && nullCtx.NULL() != null) { - operators.add(GraphStdOperatorTable.IS_NULL); + Iterator i$ = ctx.children.iterator(); + while (i$.hasNext()) { + ParseTree o = (ParseTree) i$.next(); + if (o == null) continue; + if (CypherGSParser.OC_NullPredicateExpressionContext.class.isInstance(o)) { + operand = + visitOC_NullPredicateExpression( + operand, (CypherGSParser.OC_NullPredicateExpressionContext) o); + } else if (CypherGSParser.OC_StringPredicateExpressionContext.class.isInstance(o)) { + operand = + visitOC_StringPredicateExpression( + operand, (CypherGSParser.OC_StringPredicateExpressionContext) o); } } + return operand; + } + + private ExprVisitorResult visitOC_NullPredicateExpression( + ExprVisitorResult operand, CypherGSParser.OC_NullPredicateExpressionContext nullCtx) { + List operators = 
Lists.newArrayList(); + if (nullCtx.IS() != null && nullCtx.NOT() != null && nullCtx.NULL() != null) { + operators.add(GraphStdOperatorTable.IS_NOT_NULL); + } else if (nullCtx.IS() != null && nullCtx.NULL() != null) { + operators.add(GraphStdOperatorTable.IS_NULL); + } else { + throw new IllegalArgumentException( + "unknown null predicate expression: " + nullCtx.getText()); + } return unaryCall(operators, operand); } + /** Translates {@code STARTS WITH}, {@code ENDS WITH} and {@code CONTAINS} into a {@code POSIX_REGEX_CASE_SENSITIVE} call whose right operand is a regex built from the string literal. The runtime regex match is an unanchored search, so the prefix and suffix forms are anchored explicitly with {@code ^} and {@code $}; regex metacharacters in the literal are backslash-escaped so the literal is matched verbatim. Throws IllegalArgumentException if the right operand is not a string literal or the predicate keyword is unknown. */ private ExprVisitorResult visitOC_StringPredicateExpression( + ExprVisitorResult operand, + CypherGSParser.OC_StringPredicateExpressionContext stringCtx) { + ExprVisitorResult rightRes = + visitOC_AddOrSubtractExpression(stringCtx.oC_AddOrSubtractExpression()); + RexNode rightExpr = rightRes.getExpr(); + /* the right operand should be a string literal */ + Preconditions.checkArgument( + rightExpr.getKind() == SqlKind.LITERAL + && rightExpr.getType().getFamily() == SqlTypeFamily.CHARACTER, + "the right operand of string predicate expression should be a string literal"); + /* escape regex metacharacters so the literal is matched as plain text instead of being interpreted as a pattern */ String value = ((RexLiteral) rightExpr).getValueAs(NlsString.class).getValue().replaceAll("[\\\\.+*?()|\\[\\]{}^$]", "\\\\$0"); + StringBuilder regexPattern = new StringBuilder(); + if (stringCtx.STARTS() != null) { + /* anchor at the beginning, otherwise the unanchored search would accept the value anywhere */ regexPattern.append("^").append(value); + regexPattern.append(".*"); + } else if (stringCtx.ENDS() != null) { + regexPattern.append(".*"); + /* anchor at the end, otherwise the unanchored search would accept the value anywhere */ regexPattern.append(value).append("$"); + } else if (stringCtx.CONTAINS() != null) { + regexPattern.append(".*"); + regexPattern.append(value); + regexPattern.append(".*"); + } else { + throw new IllegalArgumentException( + "unknown string predicate expression: " + stringCtx.getText()); + } + return binaryCall( + GraphStdOperatorTable.POSIX_REGEX_CASE_SENSITIVE, + ImmutableList.of( + operand, + new ExprVisitorResult( + rightRes.getAggCalls(), builder.literal(regexPattern.toString())))); + } + @Override public ExprVisitorResult visitOC_AddOrSubtractExpression( CypherGSParser.OC_AddOrSubtractExpressionContext ctx) { diff --git
a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/cypher/integration/suite/movie/MovieQueries.java b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/cypher/integration/suite/movie/MovieQueries.java index 9fd1e09583db..f718fa4c9446 100644 --- a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/cypher/integration/suite/movie/MovieQueries.java +++ b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/cypher/integration/suite/movie/MovieQueries.java @@ -301,4 +301,22 @@ public static QueryContext get_movie_query15_test() { + " \"Tom Cruise\"}>"); return new QueryContext(query, expected); } + + public static QueryContext get_movie_query16_test() { + String query = "Match (n:Movie {id: 0}) Where n.title starts with 'The' Return n.title;"; + List expected = Arrays.asList("Record<{title: \"The Matrix\"}>"); + return new QueryContext(query, expected); + } + + public static QueryContext get_movie_query17_test() { + String query = "Match (n:Movie {id: 0}) Where n.title ends with 'Matrix' Return n.title;"; + List expected = Arrays.asList("Record<{title: \"The Matrix\"}>"); + return new QueryContext(query, expected); + } + + public static QueryContext get_movie_query18_test() { + String query = "Match (n:Movie {id: 0}) Where n.title contains 'The' Return n.title;"; + List expected = Arrays.asList("Record<{title: \"The Matrix\"}>"); + return new QueryContext(query, expected); + } } diff --git a/interactive_engine/compiler/src/main/java/org/apache/calcite/sql/fun/ExtSqlPosixRegexOperator.java b/interactive_engine/compiler/src/main/java/org/apache/calcite/sql/fun/ExtSqlPosixRegexOperator.java new file mode 100644 index 000000000000..f7d9551861f8 --- /dev/null +++ b/interactive_engine/compiler/src/main/java/org/apache/calcite/sql/fun/ExtSqlPosixRegexOperator.java @@ -0,0 +1,70 @@ +/* + * Copyright 2020 Alibaba Group Holding Limited. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.calcite.sql.fun; + +import com.alibaba.graphscope.common.ir.rex.RexCallBinding; + +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.sql.SqlCallBinding; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.type.SqlTypeUtil; +import org.apache.calcite.util.Static; +import org.apache.calcite.util.Util; + +/** + * The operator is used for regex match for string values, e.g. a.name like '%marko' in a sql expression. + * The original implementation will check operand types by {@link org.apache.calcite.sql.SqlCall}, which is a structure in sql parser phase. + * Here we override the interface to check types by {@link org.apache.calcite.rex.RexCall} which represents an algebra relation. 
+ */ +public class ExtSqlPosixRegexOperator extends SqlPosixRegexOperator { + public ExtSqlPosixRegexOperator( + String name, SqlKind kind, boolean caseSensitive, boolean negated) { + super(name, kind, caseSensitive, negated); + } + + @Override + public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure) { + int operandCount = callBinding.getOperandCount(); + if (operandCount != 2) { + throw new AssertionError( + "Unexpected number of args to " + callBinding.getCall() + ": " + operandCount); + } else { + RelDataType op1Type = callBinding.getOperandType(0); + RelDataType op2Type = callBinding.getOperandType(1); + if (!SqlTypeUtil.isComparable(op1Type, op2Type)) { + throw new AssertionError( + "Incompatible first two operand types " + op1Type + " and " + op2Type); + } else { + if (!SqlTypeUtil.isCharTypeComparable(callBinding.collectOperandTypes())) { + if (throwOnFailure) { + String msg = + String.join( + ", ", + Util.transform( + ((RexCallBinding) callBinding).getRexOperands(), + String::valueOf)); + throw callBinding.newError(Static.RESOURCE.operandNotComparable(msg)); + } else { + return false; + } + } else { + return true; + } + } + } + } +} diff --git a/interactive_engine/compiler/src/test/java/com/alibaba/graphscope/common/ir/ExpressionTest.java b/interactive_engine/compiler/src/test/java/com/alibaba/graphscope/common/ir/ExpressionTest.java index a0f0357dead8..4f84562a21d4 100644 --- a/interactive_engine/compiler/src/test/java/com/alibaba/graphscope/common/ir/ExpressionTest.java +++ b/interactive_engine/compiler/src/test/java/com/alibaba/graphscope/common/ir/ExpressionTest.java @@ -182,6 +182,19 @@ public void dynamic_param_type_test() { Assert.assertEquals(SqlTypeName.INTEGER, plus.getType().getSqlTypeName()); } + @Test + public void posix_regex_test() { + RexNode regex = + builder.source(mockSourceConfig(null)) + .call( + GraphStdOperatorTable.POSIX_REGEX_CASE_SENSITIVE, + builder.variable(null, "name"), + 
builder.literal("^marko")); + Assert.assertEquals(SqlTypeName.BOOLEAN, regex.getType().getSqlTypeName()); + Assert.assertEquals( + "POSIX REGEX CASE SENSITIVE(DEFAULT.name, _UTF-8'^marko')", regex.toString()); + } + private SourceConfig mockSourceConfig(String alias) { return new SourceConfig( GraphOpt.Source.VERTEX, new LabelConfig(false).addLabel("person"), alias); diff --git a/interactive_engine/compiler/src/test/java/com/alibaba/graphscope/cypher/antlr4/MatchTest.java b/interactive_engine/compiler/src/test/java/com/alibaba/graphscope/cypher/antlr4/MatchTest.java index ca3c9c3154e4..1d4a8e3fa496 100644 --- a/interactive_engine/compiler/src/test/java/com/alibaba/graphscope/cypher/antlr4/MatchTest.java +++ b/interactive_engine/compiler/src/test/java/com/alibaba/graphscope/cypher/antlr4/MatchTest.java @@ -311,4 +311,40 @@ public void match_16_test() { + "], matchOpt=[INNER])", node.explain().trim()); } + + @Test + public void match_17_test() { + RelNode node = + Utils.eval("Match (a:person) Where a.name starts with 'marko' Return a").build(); + Assert.assertEquals( + "GraphLogicalProject(a=[a], isAppend=[false])\n" + + " GraphLogicalSource(tableConfig=[{isAll=false, tables=[person]}]," + + " alias=[a], fusedFilter=[[POSIX REGEX CASE SENSITIVE(DEFAULT.name," + + " _UTF-8'^marko.*')]], opt=[VERTEX])", + node.explain().trim()); + } + + @Test + public void match_18_test() { + RelNode node = + Utils.eval("Match (a:person) Where a.name ends with 'marko' Return a").build(); + Assert.assertEquals( + "GraphLogicalProject(a=[a], isAppend=[false])\n" + + " GraphLogicalSource(tableConfig=[{isAll=false, tables=[person]}]," + + " alias=[a], fusedFilter=[[POSIX REGEX CASE SENSITIVE(DEFAULT.name," + + " _UTF-8'.*marko$')]], opt=[VERTEX])", + node.explain().trim()); + } + + @Test + public void match_19_test() { + RelNode node = + Utils.eval("Match (a:person) Where a.name contains 'marko' Return a").build(); + Assert.assertEquals( + "GraphLogicalProject(a=[a], isAppend=[false])\n" + + 
" GraphLogicalSource(tableConfig=[{isAll=false, tables=[person]}]," + + " alias=[a], fusedFilter=[[POSIX REGEX CASE SENSITIVE(DEFAULT.name," + + " _UTF-8'.*marko.*')]], opt=[VERTEX])", + node.explain().trim()); + } } diff --git a/interactive_engine/compiler/src/test/java/com/alibaba/graphscope/cypher/integration/movie/MovieTest.java b/interactive_engine/compiler/src/test/java/com/alibaba/graphscope/cypher/integration/movie/MovieTest.java index c932fc8ed086..7d8e0a10e275 100644 --- a/interactive_engine/compiler/src/test/java/com/alibaba/graphscope/cypher/integration/movie/MovieTest.java +++ b/interactive_engine/compiler/src/test/java/com/alibaba/graphscope/cypher/integration/movie/MovieTest.java @@ -16,6 +16,8 @@ package com.alibaba.graphscope.cypher.integration.movie; +import static org.junit.Assume.assumeTrue; + import com.alibaba.graphscope.cypher.integration.suite.QueryContext; import com.alibaba.graphscope.cypher.integration.suite.movie.MovieQueries; @@ -122,6 +124,30 @@ public void run_movie_query15_test() { Assert.assertEquals(testQuery.getExpectedResult().toString(), result.list().toString()); } + @Test + public void run_movie_query16_test() { + assumeTrue("pegasus".equals(System.getenv("ENGINE_TYPE"))); + QueryContext testQuery = MovieQueries.get_movie_query16_test(); + Result result = session.run(testQuery.getQuery()); + Assert.assertEquals(testQuery.getExpectedResult().toString(), result.list().toString()); + } + + @Test + public void run_movie_query17_test() { + assumeTrue("pegasus".equals(System.getenv("ENGINE_TYPE"))); + QueryContext testQuery = MovieQueries.get_movie_query17_test(); + Result result = session.run(testQuery.getQuery()); + Assert.assertEquals(testQuery.getExpectedResult().toString(), result.list().toString()); + } + + @Test + public void run_movie_query18_test() { + assumeTrue("pegasus".equals(System.getenv("ENGINE_TYPE"))); + QueryContext testQuery = MovieQueries.get_movie_query18_test(); + Result result = 
session.run(testQuery.getQuery()); + Assert.assertEquals(testQuery.getExpectedResult().toString(), result.list().toString()); + } + @AfterClass public static void afterClass() { if (session != null) { diff --git a/interactive_engine/executor/ir/common/src/expr_parse/mod.rs b/interactive_engine/executor/ir/common/src/expr_parse/mod.rs index 1d559d0b5301..f9c6074c5821 100644 --- a/interactive_engine/executor/ir/common/src/expr_parse/mod.rs +++ b/interactive_engine/executor/ir/common/src/expr_parse/mod.rs @@ -157,7 +157,8 @@ impl ExprToken for pb::ExprOpr { | pb::Logical::Without | pb::Logical::Startswith | pb::Logical::Endswith - | pb::Logical::Isnull => 80, + | pb::Logical::Isnull + | pb::Logical::Regex => 80, pb::Logical::And => 75, pb::Logical::Or => 70, pb::Logical::Not => 110, diff --git a/interactive_engine/executor/ir/graph_proxy/Cargo.toml b/interactive_engine/executor/ir/graph_proxy/Cargo.toml index 6775437bf84c..bfcf4a0d55b3 100644 --- a/interactive_engine/executor/ir/graph_proxy/Cargo.toml +++ b/interactive_engine/executor/ir/graph_proxy/Cargo.toml @@ -18,6 +18,7 @@ pegasus_common = { path = "../../engine/pegasus/common" } ahash = "0.8" rand = "0.8.5" chrono = "0.4" +regex = "1.10" [features] default = [] diff --git a/interactive_engine/executor/ir/graph_proxy/src/utils/expr/eval.rs b/interactive_engine/executor/ir/graph_proxy/src/utils/expr/eval.rs index 3fe3320b7dec..f33f351f7355 100644 --- a/interactive_engine/executor/ir/graph_proxy/src/utils/expr/eval.rs +++ b/interactive_engine/executor/ir/graph_proxy/src/utils/expr/eval.rs @@ -247,6 +247,10 @@ pub(crate) fn apply_logical<'a>( .as_str()? 
.ends_with(b.as_str()?.as_ref()) .into()), + Regex => { + let regex = regex::Regex::new(b.as_str()?.as_ref())?; + Ok(regex.is_match(a.as_str()?.as_ref()).into()) + } Not => unreachable!(), Isnull => unreachable!(), } @@ -1233,4 +1237,65 @@ mod tests { assert_eq!(eval.eval::<_, Vertices>(Some(&ctxt)).unwrap(), expected); } } + + fn gen_regex_expression(to_match: &str, pattern: &str) -> common_pb::Expression { + let mut regex_expr = common_pb::Expression { operators: vec![] }; + let left = common_pb::ExprOpr { + node_type: None, + item: Some(common_pb::expr_opr::Item::Const(common_pb::Value { + item: Some(common_pb::value::Item::Str(to_match.to_string())), + })), + }; + regex_expr.operators.push(left); + let regex_opr = common_pb::ExprOpr { + node_type: None, + item: Some(common_pb::expr_opr::Item::Logical(common_pb::Logical::Regex as i32)), + }; + regex_expr.operators.push(regex_opr); + let right = common_pb::ExprOpr { + node_type: None, + item: Some(common_pb::expr_opr::Item::Const(common_pb::Value { + item: Some(common_pb::value::Item::Str(pattern.to_string())), + })), + }; + regex_expr.operators.push(right); + regex_expr + } + + #[test] + fn test_eval_regex() { + // TODO: the parser does not support escape characters in regex well yet. 
+ // So use gen_regex_expression() to help generate expression + let cases: Vec<(&str, &str)> = vec![ + ("Josh", r"^J"), // startWith, true + ("Josh", r"J.*"), // true + ("Josh", r"h$"), // endWith, true + ("Josh", r".*h"), // true + ("Josh", r"os"), // true + ("Josh", r"A.*"), // false + ("Josh", r".*A"), // false + ("Josh", r"ab"), // false + ("Josh", r"Josh.+"), // false + ("2010-03-14", r"^\d{4}-\d{2}-\d{2}$"), // true + (r"I categorically deny having triskaidekaphobia.", r"\b\w{13}\b"), //true + ]; + let expected: Vec<Object> = vec![ + object!(true), + object!(true), + object!(true), + object!(true), + object!(true), + object!(false), + object!(false), + object!(false), + object!(false), + object!(true), + object!(true), + ]; + + for ((to_match, pattern), expected) in cases.into_iter().zip(expected.into_iter()) { + let eval = Evaluator::try_from(gen_regex_expression(to_match, pattern)).unwrap(); + assert_eq!(eval.eval::<(), NoneContext>(None).unwrap(), expected); + } + } } diff --git a/interactive_engine/executor/ir/graph_proxy/src/utils/expr/eval_pred.rs b/interactive_engine/executor/ir/graph_proxy/src/utils/expr/eval_pred.rs index 450ec4eaedf8..82d3c009a5ce 100644 --- a/interactive_engine/executor/ir/graph_proxy/src/utils/expr/eval_pred.rs +++ b/interactive_engine/executor/ir/graph_proxy/src/utils/expr/eval_pred.rs @@ -334,7 +334,8 @@ impl EvalPred for Predicate { | Logical::Within | Logical::Without | Logical::Startswith - | Logical::Endswith => Ok(apply_logical( + | Logical::Endswith + | Logical::Regex => Ok(apply_logical( &self.cmp, self.left.eval(context)?.as_borrow_object(), Some(self.right.eval(context)?.as_borrow_object()), @@ -449,7 +450,8 @@ fn process_predicates( | Logical::Without | Logical::Startswith | Logical::Endswith - | Logical::Isnull => partial.cmp(logical)?, + | Logical::Isnull + | Logical::Regex => partial.cmp(logical)?, Logical::Not => is_not = true, Logical::And | Logical::Or => { predicates = predicates.merge_partial(curr_cmp, partial, 
is_not)?; @@ -592,6 +594,7 @@ mod tests { NameOrId::from("hobbies".to_string()), vec!["football".to_string(), "guitar".to_string()].into(), ), + (NameOrId::from("str_birthday".to_string()), "1990-04-16".to_string().into()), ] .into_iter() .collect(); @@ -955,4 +958,53 @@ mod tests { ); } } + + fn gen_regex_expression(to_match: &str, pattern: &str) -> common_pb::Expression { + let mut regex_expr = str_to_expr_pb(to_match.to_string()).unwrap(); + let regex_opr = common_pb::ExprOpr { + node_type: None, + item: Some(common_pb::expr_opr::Item::Logical(common_pb::Logical::Regex as i32)), + }; + regex_expr.operators.push(regex_opr); + let right = common_pb::ExprOpr { + node_type: None, + item: Some(common_pb::expr_opr::Item::Const(common_pb::Value { + item: Some(common_pb::value::Item::Str(pattern.to_string())), + })), + }; + regex_expr.operators.push(right); + regex_expr + } + + #[test] + fn test_eval_predicates_regex() { + // [v0: id = 1, label = 9, age = 31, name = John, birthday = 19900416, hobbies = [football, guitar]] + // [v1: id = 2, label = 11, age = 26, name = Jimmy, birthday = 19950816] + let ctxt = prepare_context(); + + // TODO: the parser does not support escape characters in regex well yet. 
+ // So use gen_regex_expression() to help generate expression + let cases: Vec<(&str, &str)> = vec![ + ("@0.name", r"^J"), // startWith, true + ("@0.name", r"J.*"), // true + ("@0.name", r"n$"), // endWith, true + ("@0.name", r".*n"), // true + ("@0.name", r"oh"), // true + ("@0.name", r"A.*"), // false + ("@0.name", r".*A"), // false + ("@0.name", r"ab"), // false + ("@0.name", r"John.+"), // false + ("@0.str_birthday", r"^\d{4}-\d{2}-\d{2}$"), // true + ]; + let expected: Vec<bool> = vec![true, true, true, true, true, false, false, false, false, true]; + + for ((to_match, pattern), expected) in cases.into_iter().zip(expected.into_iter()) { + let eval = PEvaluator::try_from(gen_regex_expression(to_match, pattern)).unwrap(); + assert_eq!( + eval.eval_bool::<_, Vertices>(Some(&ctxt)) + .unwrap(), + expected + ); + } + } } diff --git a/interactive_engine/executor/ir/graph_proxy/src/utils/expr/mod.rs b/interactive_engine/executor/ir/graph_proxy/src/utils/expr/mod.rs index b6dd71ae77b6..31de55583db3 100644 --- a/interactive_engine/executor/ir/graph_proxy/src/utils/expr/mod.rs +++ b/interactive_engine/executor/ir/graph_proxy/src/utils/expr/mod.rs @@ -44,6 +44,8 @@ pub enum ExprEvalError { UnexpectedDataType(OperatorDesc), /// Get ``None` from `Context` GetNoneFromContext, + /// Regex Error + RegexError(regex::Error), /// Unsupported Unsupported(String), /// Other unknown errors that is converted from a error description @@ -68,6 +70,7 @@ impl Display for ExprEvalError { GetNoneFromContext => write!(f, "get `None` from `Context`"), Unsupported(e) => write!(f, "unsupported: {}", e), OtherErr(e) => write!(f, "parse error {}", e), + RegexError(e) => write!(f, "regex error {}", e), } } } @@ -85,3 +88,9 @@ impl From<&str> for ExprEvalError { Self::OtherErr(str.to_string()) } } + +impl From<regex::Error> for ExprEvalError { + fn from(error: regex::Error) -> Self { + Self::RegexError(error) + } +} diff --git a/interactive_engine/executor/ir/proto/expr.proto 
b/interactive_engine/executor/ir/proto/expr.proto index afb6fe8c845f..022d1daffadb 100644 --- a/interactive_engine/executor/ir/proto/expr.proto +++ b/interactive_engine/executor/ir/proto/expr.proto @@ -79,6 +79,8 @@ enum Logical { NOT = 12; // A unary logical isnull operator ISNULL = 13; + // A binary operator to verify whether a string matches a regular expression + REGEX = 14; } enum Arithmetic {