From 7998fb7ea7643974deac540e88d7839b7d1d2a15 Mon Sep 17 00:00:00 2001 From: "wangguangxin.cn" Date: Sun, 16 Jun 2024 18:20:57 +0800 Subject: [PATCH] support array_sort --- .../gluten/execution/TestOperator.scala | 13 +++++++++++++ .../functions/RegistrationAllFunctions.cc | 1 + .../expression/ExpressionConverter.scala | 19 +++++++++++++++++++ .../expression/ExpressionMappings.scala | 1 + .../gluten/expression/ExpressionNames.scala | 1 + 5 files changed, 35 insertions(+) diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala index a892b6f313a4..f65ff2d76889 100644 --- a/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala +++ b/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala @@ -1872,4 +1872,17 @@ class TestOperator extends VeloxWholeStageTransformerSuite with AdaptiveSparkPla } } } + + test("array_sort") { + runQueryAndCompare(""" + |select array_sort(collect_list(l_orderkey)) + |from lineitem + |where l_partkey in (1552, 674, 1062) + |group by l_partkey + |""".stripMargin) { + df => + val op = collect(df.queryExecution.executedPlan) { case p: ProjectExecTransformer => p } + assert(op.size == 2) + } + } } diff --git a/cpp/velox/operators/functions/RegistrationAllFunctions.cc b/cpp/velox/operators/functions/RegistrationAllFunctions.cc index b827690d1cdf..ac34fa398aa2 100644 --- a/cpp/velox/operators/functions/RegistrationAllFunctions.cc +++ b/cpp/velox/operators/functions/RegistrationAllFunctions.cc @@ -72,6 +72,7 @@ void registerFunctionOverwrite() { velox::functions::registerBinaryIntegral({"check_subtract"}); velox::functions::registerBinaryIntegral({"check_multiply"}); velox::functions::registerBinaryIntegral({"check_divide"}); + velox::functions::prestosql::registerArrayFunctions("presto_"); } } // namespace diff --git a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala index 464bbbfd002c..965ee0adede7 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala @@ -656,6 +656,25 @@ object ExpressionConverter extends SQLConfHelper with Logging { Seq(replaceWithExpressionTransformerInternal(c.child, attributeSeq, expressionsMap)), c ) + case ArraySort(argument, function, _) if BackendsApiManager.getBackendName == "velox" => + // scalastyle:off + // We use a special expressions mapping here to translate function name + // of LessThan/GreaterThan/EqualTo to lt/gt/eq, instead of lessthan/greaterthan/equalto. + // Refer to https://github.com/facebookincubator/velox/blob/main/velox/functions/prestosql/SimpleComparisonMatcher.cpp#L79 + // scalastyle:on + val specialExpressionsMap = expressionsMap ++ Map( + (Sig[LessThan]("")).expClass -> "presto_lt", + (Sig[GreaterThan]("")).expClass -> "presto_gt", + (Sig[EqualTo]("")).expClass -> "presto_eq" + ) + GenericExpressionTransformer( + substraitExprName, + Seq( + replaceWithExpressionTransformerInternal(argument, attributeSeq, expressionsMap), + replaceWithExpressionTransformerInternal(function, attributeSeq, specialExpressionsMap) + ), + expr + ) case expr => GenericExpressionTransformer( substraitExprName, diff --git a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala index 230d91005e9c..48412f1e616c 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala @@ -231,6 +231,7 @@ object ExpressionMappings { Sig[ArrayMin](ARRAY_MIN), Sig[ArrayJoin](ARRAY_JOIN), Sig[SortArray](SORT_ARRAY), + Sig[ArraySort](ARRAY_SORT), Sig[ArraysOverlap](ARRAYS_OVERLAP), Sig[ArrayPosition](ARRAY_POSITION), Sig[ArrayDistinct](ARRAY_DISTINCT), diff --git a/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala b/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala index f817612a1e8d..88c4ff0e7407 100644 --- a/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala +++ b/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala @@ -248,6 +248,7 @@ object ExpressionNames { final val ARRAY_MIN = "array_min" final val ARRAY_JOIN = "array_join" final val SORT_ARRAY = "sort_array" + final val ARRAY_SORT = "presto_array_sort" final val ARRAYS_OVERLAP = "arrays_overlap" final val ARRAY_POSITION = "array_position" final val ARRAY_DISTINCT = "array_distinct"