From b96919df69db44afba1f0cc423f91e9887b1abda Mon Sep 17 00:00:00 2001
From: "joey.ljy" <joey.ljy@alibaba-inc.com>
Date: Wed, 15 Nov 2023 15:09:32 +0800
Subject: [PATCH] Introduce VeloxIntermediateData to adjust type and order

---
 .../HashAggregateExecTransformer.scala        | 126 +++--------------
 .../utils/VeloxIntermediateData.scala         | 131 ++++++++++++++++++
 .../utils/GlutenDecimalUtil.scala             |  40 ------
 3 files changed, 151 insertions(+), 146 deletions(-)
 create mode 100644 backends-velox/src/main/scala/io/glutenproject/utils/VeloxIntermediateData.scala
 delete mode 100644 gluten-core/src/main/scala/io/glutenproject/utils/GlutenDecimalUtil.scala

diff --git a/backends-velox/src/main/scala/io/glutenproject/execution/HashAggregateExecTransformer.scala b/backends-velox/src/main/scala/io/glutenproject/execution/HashAggregateExecTransformer.scala
index 1b8e10893a89..731b5f75a3db 100644
--- a/backends-velox/src/main/scala/io/glutenproject/execution/HashAggregateExecTransformer.scala
+++ b/backends-velox/src/main/scala/io/glutenproject/execution/HashAggregateExecTransformer.scala
@@ -16,7 +16,6 @@
  */
 package io.glutenproject.execution
 
-import io.glutenproject.execution.VeloxAggregateFunctionsBuilder._
 import io.glutenproject.expression._
 import io.glutenproject.expression.ConverterUtils.FunctionConfig
 import io.glutenproject.substrait.`type`.{TypeBuilder, TypeNode}
@@ -24,7 +23,7 @@ import io.glutenproject.substrait.{AggregationParams, SubstraitContext}
 import io.glutenproject.substrait.expression.{AggregateFunctionNode, ExpressionBuilder, ExpressionNode, ScalarFunctionNode}
 import io.glutenproject.substrait.extensions.ExtensionBuilder
 import io.glutenproject.substrait.rel.{RelBuilder, RelNode}
-import io.glutenproject.utils.GlutenDecimalUtil
+import io.glutenproject.utils.VeloxIntermediateData
 
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
@@ -80,27 +79,18 @@ case class HashAggregateExecTransformer(
    * @return
    *   extracting needed or not.
    */
-  def extractStructNeeded(): Boolean = {
-    for (expr <- aggregateExpressions) {
-      val aggregateFunction = expr.aggregateFunction
-      aggregateFunction match {
-        case _: Average | _: First | _: Last | _: StddevSamp | _: StddevPop | _: VarianceSamp |
-            _: VariancePop | _: Corr | _: CovPopulation | _: CovSample | _: MaxMinBy =>
-          expr.mode match {
-            case Partial | PartialMerge =>
-              return true
-            case _ =>
-          }
-        case sum: Sum if sum.dataType.isInstanceOf[DecimalType] =>
-          expr.mode match {
-            case Partial | PartialMerge =>
-              return true
-            case _ =>
-          }
-        case _ =>
-      }
+  private def extractStructNeeded(): Boolean = {
+    aggregateExpressions.exists {
+      expr =>
+        expr.aggregateFunction match {
+          case aggFunc if aggFunc.aggBufferAttributes.size > 1 =>
+            expr.mode match {
+              case Partial | PartialMerge => true
+              case _ => false
+            }
+          case _ => false
+        }
     }
-    false
   }
 
   /**
@@ -209,42 +199,7 @@ case class HashAggregateExecTransformer(
     }
   }
 
-  /**
-   * Return the intermediate type node of a partial aggregation in Velox.
-   * @param aggregateFunction
-   *   The aggregation function.
-   * @return
-   *   The type of partial outputs.
-   */
-  private def getIntermediateTypeNode(aggregateFunction: AggregateFunction): TypeNode = {
-    val structTypeNodes = aggregateFunction match {
-      case avg: Average =>
-        ConverterUtils.getTypeNode(GlutenDecimalUtil.getAvgSumDataType(avg), nullable = true) ::
-          ConverterUtils.getTypeNode(LongType, nullable = true) :: Nil
-      case first: First =>
-        ConverterUtils.getTypeNode(first.dataType, nullable = true) ::
-          ConverterUtils.getTypeNode(BooleanType, nullable = true) :: Nil
-      case last: Last =>
-        ConverterUtils.getTypeNode(last.dataType, nullable = true) ::
-          ConverterUtils.getTypeNode(BooleanType, nullable = true) :: Nil
-      case maxMinBy: MaxMinBy =>
-        ConverterUtils.getTypeNode(maxMinBy.valueExpr.dataType, nullable = true) ::
-          ConverterUtils.getTypeNode(maxMinBy.orderingExpr.dataType, nullable = true) :: Nil
-      case _: StddevSamp | _: StddevPop | _: VarianceSamp | _: VariancePop =>
-        // Use struct type to represent Velox Row(BIGINT, DOUBLE, DOUBLE).
-        veloxVarianceIntermediateTypes.map(ConverterUtils.getTypeNode(_, nullable = false))
-      case _: Corr =>
-        veloxCorrIntermediateTypes.map(ConverterUtils.getTypeNode(_, nullable = false))
-      case _: CovPopulation | _: CovSample =>
-        veloxCovarIntermediateTypes.map(ConverterUtils.getTypeNode(_, nullable = false))
-      case sum: Sum if sum.dataType.isInstanceOf[DecimalType] =>
-        ConverterUtils.getTypeNode(sum.dataType, nullable = true) ::
-          ConverterUtils.getTypeNode(BooleanType, nullable = false) :: Nil
-      case other =>
-        throw new UnsupportedOperationException(s"$other is not supported.")
-    }
-    TypeBuilder.makeStruct(false, structTypeNodes.asJava)
-  }
+
 
   override protected def modeToKeyWord(aggregateMode: AggregateMode): String = {
     super.modeToKeyWord(if (mixedPartialAndMerge) Partial else aggregateMode)
@@ -268,7 +223,7 @@ case class HashAggregateExecTransformer(
             VeloxAggregateFunctionsBuilder.create(args, aggregateFunction),
             childrenNodeList,
             modeKeyWord,
-            getIntermediateTypeNode(aggregateFunction))
+            VeloxIntermediateData.getIntermediateTypeNode(aggregateFunction))
           aggregateNodeList.add(partialNode)
         case PartialMerge =>
           val aggFunctionNode = ExpressionBuilder.makeAggregateFunction(
@@ -276,7 +231,7 @@ case class HashAggregateExecTransformer(
               .create(args, aggregateFunction, mixedPartialAndMerge),
             childrenNodeList,
             modeKeyWord,
-            getIntermediateTypeNode(aggregateFunction)
+            VeloxIntermediateData.getIntermediateTypeNode(aggregateFunction)
           )
           aggregateNodeList.add(aggFunctionNode)
         case Final =>
@@ -356,7 +311,7 @@ case class HashAggregateExecTransformer(
               _: VariancePop | _: Corr | _: CovPopulation | _: CovSample | _: MaxMinBy =>
             expression.mode match {
               case Partial | PartialMerge =>
-                typeNodeList.add(getIntermediateTypeNode(aggregateFunction))
+                typeNodeList.add(VeloxIntermediateData.getIntermediateTypeNode(aggregateFunction))
               case Final =>
                 typeNodeList.add(
                   ConverterUtils
@@ -367,7 +322,7 @@ case class HashAggregateExecTransformer(
           case sum: Sum if sum.dataType.isInstanceOf[DecimalType] =>
             expression.mode match {
               case Partial | PartialMerge =>
-                typeNodeList.add(getIntermediateTypeNode(aggregateFunction))
+                typeNodeList.add(VeloxIntermediateData.getIntermediateTypeNode(aggregateFunction))
               case Final =>
                 typeNodeList.add(
                   ConverterUtils
@@ -547,7 +502,7 @@ case class HashAggregateExecTransformer(
               // Spark's Corr order is [n, xAvg, yAvg, ck, xMk, yMk]
               val sparkCorrOutputAttr = aggregateFunction.inputAggBufferAttributes.map(_.name)
               val veloxInputOrder =
-                VeloxAggregateFunctionsBuilder.veloxCorrIntermediateDataOrder.map(
+                VeloxIntermediateData.veloxCorrIntermediateDataOrder.map(
                   name => sparkCorrOutputAttr.indexOf(name))
               for (order <- veloxInputOrder) {
                 val attr = functionInputAttributes(order)
@@ -590,7 +545,7 @@ case class HashAggregateExecTransformer(
               // Spark's Covar order is [n, xAvg, yAvg, ck]
               val sparkCorrOutputAttr = aggregateFunction.inputAggBufferAttributes.map(_.name)
               val veloxInputOrder =
-                VeloxAggregateFunctionsBuilder.veloxCovarIntermediateDataOrder.map(
+                VeloxIntermediateData.veloxCovarIntermediateDataOrder.map(
                   name => sparkCorrOutputAttr.indexOf(name))
               for (order <- veloxInputOrder) {
                 val attr = functionInputAttributes(order)
@@ -810,47 +765,6 @@ case class HashAggregateExecTransformer(
 /** An aggregation function builder specifically used by Velox backend. */
 object VeloxAggregateFunctionsBuilder {
 
-  val veloxCorrIntermediateDataOrder: Seq[String] = Seq("ck", "n", "xMk", "yMk", "xAvg", "yAvg")
-  val veloxCovarIntermediateDataOrder: Seq[String] = Seq("ck", "n", "xAvg", "yAvg")
-
-  val veloxVarianceIntermediateTypes: Seq[DataType] = Seq(LongType, DoubleType, DoubleType)
-  val veloxCovarIntermediateTypes: Seq[DataType] = Seq(DoubleType, LongType, DoubleType, DoubleType)
-  val veloxCorrIntermediateTypes: Seq[DataType] =
-    Seq(DoubleType, LongType, DoubleType, DoubleType, DoubleType, DoubleType)
-
-  /**
-   * Get the compatible input types for a Velox aggregate function.
-   * @param aggregateFunc:
-   *   the input aggreagate function.
-   * @param forMergeCompanion:
-   *   whether this is a special case to solve mixed aggregation phases.
-   * @return
-   *   the input types of a Velox aggregate function.
-   */
-  private def getInputTypes(
-      aggregateFunc: AggregateFunction,
-      forMergeCompanion: Boolean): Seq[DataType] = {
-    if (!forMergeCompanion) {
-      return aggregateFunc.children.map(_.dataType)
-    }
-    aggregateFunc match {
-      case _: StddevSamp | _: StddevPop | _: VarianceSamp | _: VariancePop =>
-        Seq(StructType(veloxVarianceIntermediateTypes.map(StructField("", _)).toArray))
-      case _: CovPopulation | _: CovSample =>
-        Seq(StructType(veloxCovarIntermediateTypes.map(StructField("", _)).toArray))
-      case _: Corr =>
-        Seq(StructType(veloxCorrIntermediateTypes.map(StructField("", _)).toArray))
-      case aggFunc if aggFunc.aggBufferAttributes.size > 1 =>
-        Seq(
-          StructType(
-            aggregateFunc.aggBufferAttributes
-              .map(attribute => StructField("", attribute.dataType))
-              .toArray))
-      case _ =>
-        aggregateFunc.aggBufferAttributes.map(_.dataType)
-    }
-  }
-
   /**
    * Create an scalar function for the input aggregate function.
    * @param args:
@@ -887,7 +801,7 @@ object VeloxAggregateFunctionsBuilder {
       functionMap,
       ConverterUtils.makeFuncName(
         substraitAggFuncName,
-        getInputTypes(aggregateFunc, forMergeCompanion),
+        VeloxIntermediateData.getInputTypes(aggregateFunc, forMergeCompanion),
         FunctionConfig.REQ))
   }
 }
diff --git a/backends-velox/src/main/scala/io/glutenproject/utils/VeloxIntermediateData.scala b/backends-velox/src/main/scala/io/glutenproject/utils/VeloxIntermediateData.scala
new file mode 100644
index 000000000000..5eef2695cf52
--- /dev/null
+++ b/backends-velox/src/main/scala/io/glutenproject/utils/VeloxIntermediateData.scala
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.glutenproject.utils
+
+import io.glutenproject.expression.ConverterUtils
+import io.glutenproject.substrait.`type`.{TypeBuilder, TypeNode}
+
+import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.types._
+
+import scala.collection.JavaConverters._
+
+object VeloxIntermediateData {
+  // Agg functions with inconsistent ordering of intermediate data between Velox and Spark.
+  // Corr
+  val veloxCorrIntermediateDataOrder: Seq[String] = Seq("ck", "n", "xMk", "yMk", "xAvg", "yAvg")
+  // CovPopulation, CovSample
+  val veloxCovarIntermediateDataOrder: Seq[String] = Seq("ck", "n", "xAvg", "yAvg")
+
+  // Agg functions with inconsistent types of intermediate data between Velox and Spark.
+  // StddevSamp, StddevPop, VarianceSamp, VariancePop
+  val veloxVarianceIntermediateTypes: Seq[DataType] = Seq(LongType, DoubleType, DoubleType)
+  // CovPopulation, CovSample
+  val veloxCovarIntermediateTypes: Seq[DataType] = Seq(DoubleType, LongType, DoubleType, DoubleType)
+  // Corr
+  val veloxCorrIntermediateTypes: Seq[DataType] =
+    Seq(DoubleType, LongType, DoubleType, DoubleType, DoubleType, DoubleType)
+
+  /**
+   * Return the intermediate columns order of Velox aggregation functions, with special matching
+   * required for some aggregation functions where the intermediate columns order are inconsistent
+   * with Spark.
+   * @param aggFunc
+   *   Spark aggregation function
+   * @return
+   *   the intermediate columns order of Velox aggregation functions
+   */
+  def veloxIntermediateDataOrder(aggFunc: AggregateFunction): Seq[String] = {
+    aggFunc match {
+      case _: Corr =>
+        veloxCorrIntermediateDataOrder
+      case _: CovPopulation | _: CovSample =>
+        veloxCovarIntermediateDataOrder
+      case _ =>
+        aggFunc.aggBufferAttributes.map(_.name)
+    }
+  }
+
+  /**
+   * Get the compatible input types for a Velox aggregate function.
+   *
+   * @param aggregateFunc
+   *   The input aggregate function.
+   * @param forMergeCompanion
+   *   Whether this is a special case to solve mixed aggregation phases.
+   * @return
+   *   The input types of a Velox aggregate function.
+   */
+  def getInputTypes(aggregateFunc: AggregateFunction, forMergeCompanion: Boolean): Seq[DataType] = {
+    if (!forMergeCompanion) {
+      return aggregateFunc.children.map(_.dataType)
+    }
+    aggregateFunc match {
+      case _ @Type(veloxDataTypes: Seq[DataType]) =>
+        Seq(StructType(veloxDataTypes.map(StructField("", _)).toArray))
+      case _ =>
+        // Not use StructType for single column agg intermediate data
+        aggregateFunc.aggBufferAttributes.map(_.dataType)
+    }
+  }
+
+  /**
+   * Return the intermediate type node of a partial aggregation in Velox.
+   *
+   * @param aggFunc
+   *   Spark aggregation function.
+   * @return
+   *   The type of partial outputs.
+   */
+  def getIntermediateTypeNode(aggFunc: AggregateFunction): TypeNode = {
+    val structTypeNodes =
+      aggFunc match {
+        case _ @Type(dataTypes: Seq[DataType]) =>
+          dataTypes.map(ConverterUtils.getTypeNode(_, nullable = false))
+        case _ =>
+          throw new UnsupportedOperationException("Can not get velox intermediate types.")
+      }
+    TypeBuilder.makeStruct(false, structTypeNodes.asJava)
+  }
+
+  private object Type {
+
+    /**
+     * Return the intermediate types of Velox agg functions, with special matching required for some
+     * aggregation functions where the intermediate results are inconsistent with Spark. Only return
+     * if the intermediate result has multiple columns.
+     * @param aggFunc
+     *   Spark aggregation function
+     * @return
+     *   the intermediate types of Velox aggregation functions.
+     */
+    def unapply(aggFunc: AggregateFunction): Option[Seq[DataType]] = {
+      aggFunc match {
+        case _: Corr =>
+          Some(veloxCorrIntermediateTypes)
+        case _: Covariance =>
+          Some(veloxCovarIntermediateTypes)
+        case _: StddevSamp | _: StddevPop | _: VarianceSamp | _: VariancePop =>
+          Some(veloxVarianceIntermediateTypes)
+        case _ if aggFunc.aggBufferAttributes.size > 1 =>
+          Some(aggFunc.aggBufferAttributes.map(_.dataType))
+        case _ => None
+      }
+    }
+  }
+}
diff --git a/gluten-core/src/main/scala/io/glutenproject/utils/GlutenDecimalUtil.scala b/gluten-core/src/main/scala/io/glutenproject/utils/GlutenDecimalUtil.scala
deleted file mode 100644
index e35473835d78..000000000000
--- a/gluten-core/src/main/scala/io/glutenproject/utils/GlutenDecimalUtil.scala
+++ /dev/null
@@ -1,40 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package io.glutenproject.utils
-
-import org.apache.spark.sql.catalyst.expressions.aggregate.Average
-import org.apache.spark.sql.types.{DataType, DecimalType, DoubleType}
-import org.apache.spark.sql.types.DecimalType.{MAX_PRECISION, MAX_SCALE}
-
-import scala.math.min
-
-object GlutenDecimalUtil {
-  object Fixed {
-    def unapply(t: DecimalType): Option[(Int, Int)] = Some((t.precision, t.scale))
-  }
-
-  def bounded(precision: Int, scale: Int): DecimalType = {
-    DecimalType(min(precision, MAX_PRECISION), min(scale, MAX_SCALE))
-  }
-
-  def getAvgSumDataType(avg: Average): DataType = avg.dataType match {
-    // avg.dataType is Decimal(p + 4, s + 4) and sumType is Decimal(p + 10, s)
-    // we need to get sumType, so p = p - 4 + 10 and s = s - 4
-    case _ @GlutenDecimalUtil.Fixed(p, s) => GlutenDecimalUtil.bounded(p - 4 + 10, s - 4)
-    case _ => DoubleType
-  }
-}