Skip to content

Commit

Permalink
[GLUTEN-3719][VL] Introduce VeloxIntermediateData to adjust agg func …
Browse files Browse the repository at this point in the history
…intermediate types (#3721)
  • Loading branch information
liujiayi771 authored Nov 21, 2023
1 parent 018da4c commit f29077e
Show file tree
Hide file tree
Showing 3 changed files with 172 additions and 196 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,14 @@
*/
package io.glutenproject.execution

import io.glutenproject.execution.VeloxAggregateFunctionsBuilder._
import io.glutenproject.expression._
import io.glutenproject.expression.ConverterUtils.FunctionConfig
import io.glutenproject.substrait.`type`.{TypeBuilder, TypeNode}
import io.glutenproject.substrait.{AggregationParams, SubstraitContext}
import io.glutenproject.substrait.expression.{AggregateFunctionNode, ExpressionBuilder, ExpressionNode, ScalarFunctionNode}
import io.glutenproject.substrait.extensions.ExtensionBuilder
import io.glutenproject.substrait.rel.{RelBuilder, RelNode}
import io.glutenproject.utils.GlutenDecimalUtil
import io.glutenproject.utils.VeloxIntermediateData

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
Expand Down Expand Up @@ -80,27 +79,18 @@ case class HashAggregateExecTransformer(
* @return
* extracting needed or not.
*/
def extractStructNeeded(): Boolean = {
for (expr <- aggregateExpressions) {
val aggregateFunction = expr.aggregateFunction
aggregateFunction match {
case _: Average | _: First | _: Last | _: StddevSamp | _: StddevPop | _: VarianceSamp |
_: VariancePop | _: Corr | _: CovPopulation | _: CovSample | _: MaxMinBy =>
expr.mode match {
case Partial | PartialMerge =>
return true
case _ =>
}
case sum: Sum if sum.dataType.isInstanceOf[DecimalType] =>
expr.mode match {
case Partial | PartialMerge =>
return true
case _ =>
}
case _ =>
}
private def extractStructNeeded(): Boolean = {
aggregateExpressions.exists {
expr =>
expr.aggregateFunction match {
case aggFunc if aggFunc.aggBufferAttributes.size > 1 =>
expr.mode match {
case Partial | PartialMerge => true
case _ => false
}
case _ => false
}
}
false
}

/**
Expand Down Expand Up @@ -133,56 +123,29 @@ case class HashAggregateExecTransformer(
case _ =>
throw new UnsupportedOperationException(s"${expr.mode} not supported.")
}
val aggFunc = expr.aggregateFunction
expr.aggregateFunction match {
case _: Average | _: First | _: Last | _: MaxMinBy =>
// Select first and second aggregate buffer from Velox Struct.
expressionNodes.add(ExpressionBuilder.makeSelection(colIdx, 0))
expressionNodes.add(ExpressionBuilder.makeSelection(colIdx, 1))
colIdx += 1
case _: StddevSamp | _: StddevPop | _: VarianceSamp | _: VariancePop =>
// Select count from Velox struct with count casted from LongType into DoubleType.
expressionNodes.add(
ExpressionBuilder
.makeCast(
ConverterUtils.getTypeNode(DoubleType, nullable = false),
ExpressionBuilder.makeSelection(colIdx, 0),
SQLConf.get.ansiEnabled))
// Select avg from Velox Struct.
expressionNodes.add(ExpressionBuilder.makeSelection(colIdx, 1))
// Select m2 from Velox Struct.
expressionNodes.add(ExpressionBuilder.makeSelection(colIdx, 2))
colIdx += 1
case sum: Sum if sum.dataType.isInstanceOf[DecimalType] =>
// Select sum from Velox Struct.
expressionNodes.add(ExpressionBuilder.makeSelection(colIdx, 0))
// Select isEmpty from Velox Struct.
expressionNodes.add(ExpressionBuilder.makeSelection(colIdx, 1))
colIdx += 1
case _: Corr =>
// Select count from Velox struct with count casted from LongType into DoubleType.
expressionNodes.add(
ExpressionBuilder
.makeCast(
ConverterUtils.getTypeNode(DoubleType, nullable = false),
ExpressionBuilder.makeSelection(colIdx, 1),
SQLConf.get.ansiEnabled))
expressionNodes.add(ExpressionBuilder.makeSelection(colIdx, 4))
expressionNodes.add(ExpressionBuilder.makeSelection(colIdx, 5))
expressionNodes.add(ExpressionBuilder.makeSelection(colIdx, 0))
expressionNodes.add(ExpressionBuilder.makeSelection(colIdx, 2))
expressionNodes.add(ExpressionBuilder.makeSelection(colIdx, 3))
colIdx += 1
case _: CovPopulation | _: CovSample =>
// Select count from Velox struct with count casted from LongType into DoubleType.
expressionNodes.add(
ExpressionBuilder
.makeCast(
ConverterUtils.getTypeNode(DoubleType, nullable = false),
ExpressionBuilder.makeSelection(colIdx, 1),
SQLConf.get.ansiEnabled))
expressionNodes.add(ExpressionBuilder.makeSelection(colIdx, 2))
expressionNodes.add(ExpressionBuilder.makeSelection(colIdx, 3))
expressionNodes.add(ExpressionBuilder.makeSelection(colIdx, 0))
case _ @VeloxIntermediateData.Type(veloxTypes: Seq[DataType]) =>
val (sparkOrders, sparkTypes) =
aggFunc.aggBufferAttributes.map(attr => (attr.name, attr.dataType)).unzip
val veloxOrders = VeloxIntermediateData.veloxIntermediateDataOrder(aggFunc)
val adjustedOrders = sparkOrders.map(veloxOrders.indexOf(_))
sparkTypes.zipWithIndex.foreach {
case (sparkType, idx) =>
val veloxType = veloxTypes(adjustedOrders(idx))
if (veloxType != sparkType) {
// Velox and Spark have different type, adding a cast expression
expressionNodes.add(
ExpressionBuilder
.makeCast(
ConverterUtils.getTypeNode(sparkType, nullable = false),
ExpressionBuilder.makeSelection(colIdx, adjustedOrders(idx)),
SQLConf.get.ansiEnabled))
} else {
// Velox and Spark have the same type
expressionNodes.add(ExpressionBuilder.makeSelection(colIdx, adjustedOrders(idx)))
}
}
colIdx += 1
case _ =>
expressionNodes.add(ExpressionBuilder.makeSelection(colIdx))
Expand All @@ -209,43 +172,6 @@ case class HashAggregateExecTransformer(
}
}

/**
* Return the intermediate type node of a partial aggregation in Velox.
* @param aggregateFunction
* The aggregation function.
* @return
* The type of partial outputs.
*/
private def getIntermediateTypeNode(aggregateFunction: AggregateFunction): TypeNode = {
val structTypeNodes = aggregateFunction match {
case avg: Average =>
ConverterUtils.getTypeNode(GlutenDecimalUtil.getAvgSumDataType(avg), nullable = true) ::
ConverterUtils.getTypeNode(LongType, nullable = true) :: Nil
case first: First =>
ConverterUtils.getTypeNode(first.dataType, nullable = true) ::
ConverterUtils.getTypeNode(BooleanType, nullable = true) :: Nil
case last: Last =>
ConverterUtils.getTypeNode(last.dataType, nullable = true) ::
ConverterUtils.getTypeNode(BooleanType, nullable = true) :: Nil
case maxMinBy: MaxMinBy =>
ConverterUtils.getTypeNode(maxMinBy.valueExpr.dataType, nullable = true) ::
ConverterUtils.getTypeNode(maxMinBy.orderingExpr.dataType, nullable = true) :: Nil
case _: StddevSamp | _: StddevPop | _: VarianceSamp | _: VariancePop =>
// Use struct type to represent Velox Row(BIGINT, DOUBLE, DOUBLE).
veloxVarianceIntermediateTypes.map(ConverterUtils.getTypeNode(_, nullable = false))
case _: Corr =>
veloxCorrIntermediateTypes.map(ConverterUtils.getTypeNode(_, nullable = false))
case _: CovPopulation | _: CovSample =>
veloxCovarIntermediateTypes.map(ConverterUtils.getTypeNode(_, nullable = false))
case sum: Sum if sum.dataType.isInstanceOf[DecimalType] =>
ConverterUtils.getTypeNode(sum.dataType, nullable = true) ::
ConverterUtils.getTypeNode(BooleanType, nullable = false) :: Nil
case other =>
throw new UnsupportedOperationException(s"$other is not supported.")
}
TypeBuilder.makeStruct(false, structTypeNodes.asJava)
}

override protected def modeToKeyWord(aggregateMode: AggregateMode): String = {
super.modeToKeyWord(if (mixedPartialAndMerge) Partial else aggregateMode)
}
Expand All @@ -268,15 +194,16 @@ case class HashAggregateExecTransformer(
VeloxAggregateFunctionsBuilder.create(args, aggregateFunction),
childrenNodeList,
modeKeyWord,
getIntermediateTypeNode(aggregateFunction))
VeloxIntermediateData.getIntermediateTypeNode(aggregateFunction)
)
aggregateNodeList.add(partialNode)
case PartialMerge =>
val aggFunctionNode = ExpressionBuilder.makeAggregateFunction(
VeloxAggregateFunctionsBuilder
.create(args, aggregateFunction, mixedPartialAndMerge),
childrenNodeList,
modeKeyWord,
getIntermediateTypeNode(aggregateFunction)
VeloxIntermediateData.getIntermediateTypeNode(aggregateFunction)
)
aggregateNodeList.add(aggFunctionNode)
case Final =>
Expand Down Expand Up @@ -356,7 +283,7 @@ case class HashAggregateExecTransformer(
_: VariancePop | _: Corr | _: CovPopulation | _: CovSample | _: MaxMinBy =>
expression.mode match {
case Partial | PartialMerge =>
typeNodeList.add(getIntermediateTypeNode(aggregateFunction))
typeNodeList.add(VeloxIntermediateData.getIntermediateTypeNode(aggregateFunction))
case Final =>
typeNodeList.add(
ConverterUtils
Expand All @@ -367,7 +294,7 @@ case class HashAggregateExecTransformer(
case sum: Sum if sum.dataType.isInstanceOf[DecimalType] =>
expression.mode match {
case Partial | PartialMerge =>
typeNodeList.add(getIntermediateTypeNode(aggregateFunction))
typeNodeList.add(VeloxIntermediateData.getIntermediateTypeNode(aggregateFunction))
case Final =>
typeNodeList.add(
ConverterUtils
Expand Down Expand Up @@ -547,7 +474,7 @@ case class HashAggregateExecTransformer(
// Spark's Corr order is [n, xAvg, yAvg, ck, xMk, yMk]
val sparkCorrOutputAttr = aggregateFunction.inputAggBufferAttributes.map(_.name)
val veloxInputOrder =
VeloxAggregateFunctionsBuilder.veloxCorrIntermediateDataOrder.map(
VeloxIntermediateData.veloxCorrIntermediateDataOrder.map(
name => sparkCorrOutputAttr.indexOf(name))
for (order <- veloxInputOrder) {
val attr = functionInputAttributes(order)
Expand Down Expand Up @@ -590,7 +517,7 @@ case class HashAggregateExecTransformer(
// Spark's Covar order is [n, xAvg, yAvg, ck]
val sparkCorrOutputAttr = aggregateFunction.inputAggBufferAttributes.map(_.name)
val veloxInputOrder =
VeloxAggregateFunctionsBuilder.veloxCovarIntermediateDataOrder.map(
VeloxIntermediateData.veloxCovarIntermediateDataOrder.map(
name => sparkCorrOutputAttr.indexOf(name))
for (order <- veloxInputOrder) {
val attr = functionInputAttributes(order)
Expand Down Expand Up @@ -810,47 +737,6 @@ case class HashAggregateExecTransformer(
/** An aggregation function builder specifically used by Velox backend. */
object VeloxAggregateFunctionsBuilder {

val veloxCorrIntermediateDataOrder: Seq[String] = Seq("ck", "n", "xMk", "yMk", "xAvg", "yAvg")
val veloxCovarIntermediateDataOrder: Seq[String] = Seq("ck", "n", "xAvg", "yAvg")

val veloxVarianceIntermediateTypes: Seq[DataType] = Seq(LongType, DoubleType, DoubleType)
val veloxCovarIntermediateTypes: Seq[DataType] = Seq(DoubleType, LongType, DoubleType, DoubleType)
val veloxCorrIntermediateTypes: Seq[DataType] =
Seq(DoubleType, LongType, DoubleType, DoubleType, DoubleType, DoubleType)

/**
* Get the compatible input types for a Velox aggregate function.
* @param aggregateFunc:
* the input aggreagate function.
* @param forMergeCompanion:
* whether this is a special case to solve mixed aggregation phases.
* @return
* the input types of a Velox aggregate function.
*/
private def getInputTypes(
aggregateFunc: AggregateFunction,
forMergeCompanion: Boolean): Seq[DataType] = {
if (!forMergeCompanion) {
return aggregateFunc.children.map(_.dataType)
}
aggregateFunc match {
case _: StddevSamp | _: StddevPop | _: VarianceSamp | _: VariancePop =>
Seq(StructType(veloxVarianceIntermediateTypes.map(StructField("", _)).toArray))
case _: CovPopulation | _: CovSample =>
Seq(StructType(veloxCovarIntermediateTypes.map(StructField("", _)).toArray))
case _: Corr =>
Seq(StructType(veloxCorrIntermediateTypes.map(StructField("", _)).toArray))
case aggFunc if aggFunc.aggBufferAttributes.size > 1 =>
Seq(
StructType(
aggregateFunc.aggBufferAttributes
.map(attribute => StructField("", attribute.dataType))
.toArray))
case _ =>
aggregateFunc.aggBufferAttributes.map(_.dataType)
}
}

/**
* Create an scalar function for the input aggregate function.
* @param args:
Expand Down Expand Up @@ -887,7 +773,7 @@ object VeloxAggregateFunctionsBuilder {
functionMap,
ConverterUtils.makeFuncName(
substraitAggFuncName,
getInputTypes(aggregateFunc, forMergeCompanion),
VeloxIntermediateData.getInputTypes(aggregateFunc, forMergeCompanion),
FunctionConfig.REQ))
}
}
Loading

0 comments on commit f29077e

Please sign in to comment.