diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala index 990991c71660..0eb6126876b5 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala @@ -438,13 +438,13 @@ object VeloxBackendSettings extends BackendSettingsApi { plan match { case exec: HashAggregateExec if exec.aggregateExpressions.nonEmpty => - // Check Sum(1) or Count(1). + // Check Sum(Literal) or Count(Literal). exec.aggregateExpressions.forall( expression => { val aggFunction = expression.aggregateFunction aggFunction match { - case _: Sum | _: Count => - aggFunction.children.size == 1 && aggFunction.children.head.equals(Literal(1)) + case Sum(Literal(_, _), _) => true + case Count(Seq(Literal(_, _))) => true case _ => false } }) diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala index 13ade14b5943..897c1c5f58d5 100644 --- a/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala +++ b/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala @@ -754,6 +754,12 @@ abstract class ScalarFunctionsValidateSuite extends FunctionsValidateTest { } } + test("Test sum/count function") { + runQueryAndCompare("""SELECT sum(2),count(2) from lineitem""".stripMargin) { + checkGlutenOperatorMatch[BatchScanExecTransformer] + } + } + test("Test spark_partition_id function") { runQueryAndCompare("""SELECT spark_partition_id(), l_orderkey | from lineitem limit 100""".stripMargin) {