Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GLUTEN-3719][VL] Introduce VeloxIntermediateData to adjust velox agg func intermediate data #3721

Merged
merged 4 commits into from
Nov 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,14 @@
*/
package io.glutenproject.execution

import io.glutenproject.execution.VeloxAggregateFunctionsBuilder._
import io.glutenproject.expression._
import io.glutenproject.expression.ConverterUtils.FunctionConfig
import io.glutenproject.substrait.`type`.{TypeBuilder, TypeNode}
import io.glutenproject.substrait.{AggregationParams, SubstraitContext}
import io.glutenproject.substrait.expression.{AggregateFunctionNode, ExpressionBuilder, ExpressionNode, ScalarFunctionNode}
import io.glutenproject.substrait.extensions.ExtensionBuilder
import io.glutenproject.substrait.rel.{RelBuilder, RelNode}
import io.glutenproject.utils.GlutenDecimalUtil
import io.glutenproject.utils.VeloxIntermediateData

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
Expand Down Expand Up @@ -80,27 +79,18 @@ case class HashAggregateExecTransformer(
* @return
* whether extracting is needed.
*/
def extractStructNeeded(): Boolean = {
for (expr <- aggregateExpressions) {
val aggregateFunction = expr.aggregateFunction
aggregateFunction match {
case _: Average | _: First | _: Last | _: StddevSamp | _: StddevPop | _: VarianceSamp |
_: VariancePop | _: Corr | _: CovPopulation | _: CovSample | _: MaxMinBy =>
expr.mode match {
case Partial | PartialMerge =>
return true
case _ =>
}
case sum: Sum if sum.dataType.isInstanceOf[DecimalType] =>
expr.mode match {
case Partial | PartialMerge =>
return true
case _ =>
}
case _ =>
}
private def extractStructNeeded(): Boolean = {
aggregateExpressions.exists {
expr =>
expr.aggregateFunction match {
case aggFunc if aggFunc.aggBufferAttributes.size > 1 =>
expr.mode match {
case Partial | PartialMerge => true
case _ => false
}
case _ => false
}
}
false
}

/**
Expand Down Expand Up @@ -133,56 +123,29 @@ case class HashAggregateExecTransformer(
case _ =>
throw new UnsupportedOperationException(s"${expr.mode} not supported.")
}
val aggFunc = expr.aggregateFunction
expr.aggregateFunction match {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use aggFunc defined in the previous line.

case _: Average | _: First | _: Last | _: MaxMinBy =>
// Select first and second aggregate buffer from Velox Struct.
expressionNodes.add(ExpressionBuilder.makeSelection(colIdx, 0))
expressionNodes.add(ExpressionBuilder.makeSelection(colIdx, 1))
colIdx += 1
case _: StddevSamp | _: StddevPop | _: VarianceSamp | _: VariancePop =>
// Select count from Velox struct with count casted from LongType into DoubleType.
expressionNodes.add(
ExpressionBuilder
.makeCast(
ConverterUtils.getTypeNode(DoubleType, nullable = false),
ExpressionBuilder.makeSelection(colIdx, 0),
SQLConf.get.ansiEnabled))
// Select avg from Velox Struct.
expressionNodes.add(ExpressionBuilder.makeSelection(colIdx, 1))
// Select m2 from Velox Struct.
expressionNodes.add(ExpressionBuilder.makeSelection(colIdx, 2))
colIdx += 1
case sum: Sum if sum.dataType.isInstanceOf[DecimalType] =>
// Select sum from Velox Struct.
expressionNodes.add(ExpressionBuilder.makeSelection(colIdx, 0))
// Select isEmpty from Velox Struct.
expressionNodes.add(ExpressionBuilder.makeSelection(colIdx, 1))
colIdx += 1
case _: Corr =>
// Select count from Velox struct with count casted from LongType into DoubleType.
expressionNodes.add(
ExpressionBuilder
.makeCast(
ConverterUtils.getTypeNode(DoubleType, nullable = false),
ExpressionBuilder.makeSelection(colIdx, 1),
SQLConf.get.ansiEnabled))
expressionNodes.add(ExpressionBuilder.makeSelection(colIdx, 4))
expressionNodes.add(ExpressionBuilder.makeSelection(colIdx, 5))
expressionNodes.add(ExpressionBuilder.makeSelection(colIdx, 0))
expressionNodes.add(ExpressionBuilder.makeSelection(colIdx, 2))
expressionNodes.add(ExpressionBuilder.makeSelection(colIdx, 3))
colIdx += 1
case _: CovPopulation | _: CovSample =>
// Select count from Velox struct with count casted from LongType into DoubleType.
expressionNodes.add(
ExpressionBuilder
.makeCast(
ConverterUtils.getTypeNode(DoubleType, nullable = false),
ExpressionBuilder.makeSelection(colIdx, 1),
SQLConf.get.ansiEnabled))
expressionNodes.add(ExpressionBuilder.makeSelection(colIdx, 2))
expressionNodes.add(ExpressionBuilder.makeSelection(colIdx, 3))
expressionNodes.add(ExpressionBuilder.makeSelection(colIdx, 0))
case _ @VeloxIntermediateData.Type(veloxTypes: Seq[DataType]) =>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm confused here. Why can we match aggregateFunction against VeloxIntermediateData.Type?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

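This is Scala's extractor-pattern syntax: the case calls the VeloxIntermediateData.Type.unapply method on aggFunc, and the branch is taken only when the extractor succeeds, binding the extracted value to veloxTypes. It is roughly equivalent to val veloxTypes = VeloxIntermediateData.Type.unapply(aggFunc), followed by a check that a value was actually returned.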

val (sparkOrders, sparkTypes) =
aggFunc.aggBufferAttributes.map(attr => (attr.name, attr.dataType)).unzip
val veloxOrders = VeloxIntermediateData.veloxIntermediateDataOrder(aggFunc)
val adjustedOrders = sparkOrders.map(veloxOrders.indexOf(_))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it enough to decide the order based on name equality? For example, if attr.name carries an exprId suffix, would it fail to match the strings in veloxOrders?

Copy link
Contributor Author

@liujiayi771 liujiayi771 Nov 21, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sparkTypes.zipWithIndex.foreach {
case (sparkType, idx) =>
val veloxType = veloxTypes(adjustedOrders(idx))
if (veloxType != sparkType) {
// Velox and Spark have different type, adding a cast expression
expressionNodes.add(
ExpressionBuilder
.makeCast(
ConverterUtils.getTypeNode(sparkType, nullable = false),
ExpressionBuilder.makeSelection(colIdx, adjustedOrders(idx)),
SQLConf.get.ansiEnabled))
} else {
// Velox and Spark have the same type
expressionNodes.add(ExpressionBuilder.makeSelection(colIdx, adjustedOrders(idx)))
}
}
colIdx += 1
case _ =>
expressionNodes.add(ExpressionBuilder.makeSelection(colIdx))
Expand All @@ -209,43 +172,6 @@ case class HashAggregateExecTransformer(
}
}

/**
 * Build the struct type node describing the partial (intermediate) output that Velox produces
 * for an aggregate function.
 *
 * @param aggregateFunction
 *   The aggregation function.
 * @return
 *   The type of partial outputs: a struct whose fields are the function's intermediate values.
 * @throws UnsupportedOperationException
 *   if the function has no struct-typed intermediate output defined here.
 */
private def getIntermediateTypeNode(aggregateFunction: AggregateFunction): TypeNode = {
  // Local shorthands so each branch below reads as a plain list of field types.
  def nullableNode(dataType: DataType) =
    ConverterUtils.getTypeNode(dataType, nullable = true)
  def requiredNode(dataType: DataType) =
    ConverterUtils.getTypeNode(dataType, nullable = false)

  val fieldTypeNodes = aggregateFunction match {
    case avg: Average =>
      // (sum, count)
      Seq(nullableNode(GlutenDecimalUtil.getAvgSumDataType(avg)), nullableNode(LongType))
    case firstOrLast @ (_: First | _: Last) =>
      // (value, boolean flag)
      Seq(nullableNode(firstOrLast.dataType), nullableNode(BooleanType))
    case maxMinBy: MaxMinBy =>
      // (value, ordering)
      Seq(
        nullableNode(maxMinBy.valueExpr.dataType),
        nullableNode(maxMinBy.orderingExpr.dataType))
    case _: StddevSamp | _: StddevPop | _: VarianceSamp | _: VariancePop =>
      // Use struct type to represent Velox Row(BIGINT, DOUBLE, DOUBLE).
      veloxVarianceIntermediateTypes.map(requiredNode)
    case _: Corr =>
      veloxCorrIntermediateTypes.map(requiredNode)
    case _: CovPopulation | _: CovSample =>
      veloxCovarIntermediateTypes.map(requiredNode)
    case sum: Sum if sum.dataType.isInstanceOf[DecimalType] =>
      // Decimal sum carries a nullable sum plus a non-nullable boolean.
      Seq(nullableNode(sum.dataType), requiredNode(BooleanType))
    case other =>
      throw new UnsupportedOperationException(s"$other is not supported.")
  }
  // NOTE(review): the leading `false` is presumably the struct's own nullability - confirm
  // against TypeBuilder.makeStruct.
  TypeBuilder.makeStruct(false, fieldTypeNodes.asJava)
}

override protected def modeToKeyWord(aggregateMode: AggregateMode): String = {
  // When partial and merge phases are mixed, every mode is reported to the parent
  // implementation as Partial; otherwise the caller's mode is passed through unchanged.
  val effectiveMode = if (mixedPartialAndMerge) Partial else aggregateMode
  super.modeToKeyWord(effectiveMode)
}
Expand All @@ -268,15 +194,16 @@ case class HashAggregateExecTransformer(
VeloxAggregateFunctionsBuilder.create(args, aggregateFunction),
childrenNodeList,
modeKeyWord,
getIntermediateTypeNode(aggregateFunction))
VeloxIntermediateData.getIntermediateTypeNode(aggregateFunction)
)
aggregateNodeList.add(partialNode)
case PartialMerge =>
val aggFunctionNode = ExpressionBuilder.makeAggregateFunction(
VeloxAggregateFunctionsBuilder
.create(args, aggregateFunction, mixedPartialAndMerge),
childrenNodeList,
modeKeyWord,
getIntermediateTypeNode(aggregateFunction)
VeloxIntermediateData.getIntermediateTypeNode(aggregateFunction)
)
aggregateNodeList.add(aggFunctionNode)
case Final =>
Expand Down Expand Up @@ -356,7 +283,7 @@ case class HashAggregateExecTransformer(
_: VariancePop | _: Corr | _: CovPopulation | _: CovSample | _: MaxMinBy =>
expression.mode match {
case Partial | PartialMerge =>
typeNodeList.add(getIntermediateTypeNode(aggregateFunction))
typeNodeList.add(VeloxIntermediateData.getIntermediateTypeNode(aggregateFunction))
case Final =>
typeNodeList.add(
ConverterUtils
Expand All @@ -367,7 +294,7 @@ case class HashAggregateExecTransformer(
case sum: Sum if sum.dataType.isInstanceOf[DecimalType] =>
expression.mode match {
case Partial | PartialMerge =>
typeNodeList.add(getIntermediateTypeNode(aggregateFunction))
typeNodeList.add(VeloxIntermediateData.getIntermediateTypeNode(aggregateFunction))
case Final =>
typeNodeList.add(
ConverterUtils
Expand Down Expand Up @@ -547,7 +474,7 @@ case class HashAggregateExecTransformer(
// Spark's Corr order is [n, xAvg, yAvg, ck, xMk, yMk]
val sparkCorrOutputAttr = aggregateFunction.inputAggBufferAttributes.map(_.name)
val veloxInputOrder =
VeloxAggregateFunctionsBuilder.veloxCorrIntermediateDataOrder.map(
VeloxIntermediateData.veloxCorrIntermediateDataOrder.map(
name => sparkCorrOutputAttr.indexOf(name))
for (order <- veloxInputOrder) {
val attr = functionInputAttributes(order)
Expand Down Expand Up @@ -590,7 +517,7 @@ case class HashAggregateExecTransformer(
// Spark's Covar order is [n, xAvg, yAvg, ck]
val sparkCorrOutputAttr = aggregateFunction.inputAggBufferAttributes.map(_.name)
val veloxInputOrder =
VeloxAggregateFunctionsBuilder.veloxCovarIntermediateDataOrder.map(
VeloxIntermediateData.veloxCovarIntermediateDataOrder.map(
name => sparkCorrOutputAttr.indexOf(name))
for (order <- veloxInputOrder) {
val attr = functionInputAttributes(order)
Expand Down Expand Up @@ -810,47 +737,6 @@ case class HashAggregateExecTransformer(
/** An aggregation function builder specifically used by Velox backend. */
object VeloxAggregateFunctionsBuilder {

// Field names of Velox's intermediate data for corr, in Velox's order. Spark's Corr buffer
// order is [n, xAvg, yAvg, ck, xMk, yMk], so callers reorder fields by looking up these names.
val veloxCorrIntermediateDataOrder: Seq[String] = Seq("ck", "n", "xMk", "yMk", "xAvg", "yAvg")
// Field names of Velox's intermediate data for covariance, in Velox's order. Spark's Covar
// buffer order is [n, xAvg, yAvg, ck].
val veloxCovarIntermediateDataOrder: Seq[String] = Seq("ck", "n", "xAvg", "yAvg")

// Velox intermediate row for stddev/variance: Row(BIGINT, DOUBLE, DOUBLE).
val veloxVarianceIntermediateTypes: Seq[DataType] = Seq(LongType, DoubleType, DoubleType)
// Field types of Velox's covariance intermediate row; the LongType at index 1 lines up with
// "n" in veloxCovarIntermediateDataOrder (NOTE(review): confirm full positional alignment).
val veloxCovarIntermediateTypes: Seq[DataType] = Seq(DoubleType, LongType, DoubleType, DoubleType)
// Field types of Velox's corr intermediate row; the LongType at index 1 lines up with "n"
// in veloxCorrIntermediateDataOrder (NOTE(review): confirm full positional alignment).
val veloxCorrIntermediateTypes: Seq[DataType] =
Seq(DoubleType, LongType, DoubleType, DoubleType, DoubleType, DoubleType)

/**
 * Get the compatible input types for a Velox aggregate function.
 *
 * @param aggregateFunc
 *   the input aggregate function.
 * @param forMergeCompanion
 *   whether this is a special case to solve mixed aggregation phases.
 * @return
 *   the input types of a Velox aggregate function: the children's data types when
 *   `forMergeCompanion` is false, otherwise the function's intermediate type(s).
 */
private def getInputTypes(
    aggregateFunc: AggregateFunction,
    forMergeCompanion: Boolean): Seq[DataType] = {
  // Wrap a list of intermediate field types into a single struct with unnamed fields.
  def asSingleStruct(fieldTypes: Seq[DataType]): Seq[DataType] =
    Seq(StructType(fieldTypes.map(StructField("", _)).toArray))

  if (forMergeCompanion) {
    aggregateFunc match {
      // These functions use the Velox-specific intermediate layouts rather than Spark's
      // aggregation buffer attributes.
      case _: StddevSamp | _: StddevPop | _: VarianceSamp | _: VariancePop =>
        asSingleStruct(veloxVarianceIntermediateTypes)
      case _: CovPopulation | _: CovSample =>
        asSingleStruct(veloxCovarIntermediateTypes)
      case _: Corr =>
        asSingleStruct(veloxCorrIntermediateTypes)
      // Any other function with several buffer attributes is passed as one struct of them.
      case multiBuffer if multiBuffer.aggBufferAttributes.size > 1 =>
        asSingleStruct(multiBuffer.aggBufferAttributes.map(_.dataType))
      // At most one buffer attribute: pass its type through as-is.
      case singleBuffer =>
        singleBuffer.aggBufferAttributes.map(_.dataType)
    }
  } else {
    aggregateFunc.children.map(_.dataType)
  }
}

/**
* Create a scalar function for the input aggregate function.
* @param args:
Expand Down Expand Up @@ -887,7 +773,7 @@ object VeloxAggregateFunctionsBuilder {
functionMap,
ConverterUtils.makeFuncName(
substraitAggFuncName,
getInputTypes(aggregateFunc, forMergeCompanion),
VeloxIntermediateData.getInputTypes(aggregateFunc, forMergeCompanion),
FunctionConfig.REQ))
}
}
Loading
Loading