From 35a860aaf3ba46e30f85cefbc941e64288f0eb3d Mon Sep 17 00:00:00 2001 From: zml1206 Date: Tue, 30 Jul 2024 13:10:28 +0800 Subject: [PATCH] [VL] Support Sum(Literal) --- .../org/apache/gluten/backendsapi/velox/VeloxBackend.scala | 6 +++--- .../gluten/execution/ScalarFunctionsValidateSuite.scala | 6 ++++++ 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala index 990991c71660..0eb6126876b5 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala @@ -438,13 +438,13 @@ object VeloxBackendSettings extends BackendSettingsApi { plan match { case exec: HashAggregateExec if exec.aggregateExpressions.nonEmpty => - // Check Sum(1) or Count(1). + // Check Sum(Literal) or Count(Literal). exec.aggregateExpressions.forall( expression => { val aggFunction = expression.aggregateFunction aggFunction match { - case _: Sum | _: Count => - aggFunction.children.size == 1 && aggFunction.children.head.equals(Literal(1)) + case Sum(Literal(_, _), _) => true + case Count(Seq(Literal(_, _))) => true case _ => false } }) diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala index 43d54fb62e4b..346ec898ae5e 100644 --- a/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala +++ b/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala @@ -754,6 +754,12 @@ abstract class ScalarFunctionsValidateSuite extends FunctionsValidateTest { } } + test("Test sum/count function") { + runQueryAndCompare("""SELECT sum(2),count(2) from lineitem""".stripMargin) { + checkGlutenOperatorMatch[BatchScanExecTransformer] + } + } + test("Test spark_partition_id function") { runQueryAndCompare("""SELECT spark_partition_id(), l_orderkey | from lineitem limit 100""".stripMargin) {