-
Notifications
You must be signed in to change notification settings - Fork 446
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[GLUTEN-3719][VL] Introduce VeloxIntermediateData to adjust velox agg func intermediate data #3721
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,15 +16,14 @@ | |
*/ | ||
package io.glutenproject.execution | ||
|
||
import io.glutenproject.execution.VeloxAggregateFunctionsBuilder._ | ||
import io.glutenproject.expression._ | ||
import io.glutenproject.expression.ConverterUtils.FunctionConfig | ||
import io.glutenproject.substrait.`type`.{TypeBuilder, TypeNode} | ||
import io.glutenproject.substrait.{AggregationParams, SubstraitContext} | ||
import io.glutenproject.substrait.expression.{AggregateFunctionNode, ExpressionBuilder, ExpressionNode, ScalarFunctionNode} | ||
import io.glutenproject.substrait.extensions.ExtensionBuilder | ||
import io.glutenproject.substrait.rel.{RelBuilder, RelNode} | ||
import io.glutenproject.utils.GlutenDecimalUtil | ||
import io.glutenproject.utils.VeloxIntermediateData | ||
|
||
import org.apache.spark.sql.catalyst.expressions._ | ||
import org.apache.spark.sql.catalyst.expressions.aggregate._ | ||
|
@@ -80,27 +79,18 @@ case class HashAggregateExecTransformer( | |
* @return | ||
* whether extracting is needed. | ||
*/ | ||
def extractStructNeeded(): Boolean = { | ||
for (expr <- aggregateExpressions) { | ||
val aggregateFunction = expr.aggregateFunction | ||
aggregateFunction match { | ||
case _: Average | _: First | _: Last | _: StddevSamp | _: StddevPop | _: VarianceSamp | | ||
_: VariancePop | _: Corr | _: CovPopulation | _: CovSample | _: MaxMinBy => | ||
expr.mode match { | ||
case Partial | PartialMerge => | ||
return true | ||
case _ => | ||
} | ||
case sum: Sum if sum.dataType.isInstanceOf[DecimalType] => | ||
expr.mode match { | ||
case Partial | PartialMerge => | ||
return true | ||
case _ => | ||
} | ||
case _ => | ||
} | ||
private def extractStructNeeded(): Boolean = { | ||
aggregateExpressions.exists { | ||
expr => | ||
expr.aggregateFunction match { | ||
case aggFunc if aggFunc.aggBufferAttributes.size > 1 => | ||
expr.mode match { | ||
case Partial | PartialMerge => true | ||
case _ => false | ||
} | ||
case _ => false | ||
} | ||
} | ||
false | ||
} | ||
|
||
/** | ||
|
@@ -133,56 +123,29 @@ case class HashAggregateExecTransformer( | |
case _ => | ||
throw new UnsupportedOperationException(s"${expr.mode} not supported.") | ||
} | ||
val aggFunc = expr.aggregateFunction | ||
expr.aggregateFunction match { | ||
case _: Average | _: First | _: Last | _: MaxMinBy => | ||
// Select first and second aggregate buffer from Velox Struct. | ||
expressionNodes.add(ExpressionBuilder.makeSelection(colIdx, 0)) | ||
expressionNodes.add(ExpressionBuilder.makeSelection(colIdx, 1)) | ||
colIdx += 1 | ||
case _: StddevSamp | _: StddevPop | _: VarianceSamp | _: VariancePop => | ||
// Select count from Velox struct with count casted from LongType into DoubleType. | ||
expressionNodes.add( | ||
ExpressionBuilder | ||
.makeCast( | ||
ConverterUtils.getTypeNode(DoubleType, nullable = false), | ||
ExpressionBuilder.makeSelection(colIdx, 0), | ||
SQLConf.get.ansiEnabled)) | ||
// Select avg from Velox Struct. | ||
expressionNodes.add(ExpressionBuilder.makeSelection(colIdx, 1)) | ||
// Select m2 from Velox Struct. | ||
expressionNodes.add(ExpressionBuilder.makeSelection(colIdx, 2)) | ||
colIdx += 1 | ||
case sum: Sum if sum.dataType.isInstanceOf[DecimalType] => | ||
// Select sum from Velox Struct. | ||
expressionNodes.add(ExpressionBuilder.makeSelection(colIdx, 0)) | ||
// Select isEmpty from Velox Struct. | ||
expressionNodes.add(ExpressionBuilder.makeSelection(colIdx, 1)) | ||
colIdx += 1 | ||
case _: Corr => | ||
// Select count from Velox struct with count casted from LongType into DoubleType. | ||
expressionNodes.add( | ||
ExpressionBuilder | ||
.makeCast( | ||
ConverterUtils.getTypeNode(DoubleType, nullable = false), | ||
ExpressionBuilder.makeSelection(colIdx, 1), | ||
SQLConf.get.ansiEnabled)) | ||
expressionNodes.add(ExpressionBuilder.makeSelection(colIdx, 4)) | ||
expressionNodes.add(ExpressionBuilder.makeSelection(colIdx, 5)) | ||
expressionNodes.add(ExpressionBuilder.makeSelection(colIdx, 0)) | ||
expressionNodes.add(ExpressionBuilder.makeSelection(colIdx, 2)) | ||
expressionNodes.add(ExpressionBuilder.makeSelection(colIdx, 3)) | ||
colIdx += 1 | ||
case _: CovPopulation | _: CovSample => | ||
// Select count from Velox struct with count casted from LongType into DoubleType. | ||
expressionNodes.add( | ||
ExpressionBuilder | ||
.makeCast( | ||
ConverterUtils.getTypeNode(DoubleType, nullable = false), | ||
ExpressionBuilder.makeSelection(colIdx, 1), | ||
SQLConf.get.ansiEnabled)) | ||
expressionNodes.add(ExpressionBuilder.makeSelection(colIdx, 2)) | ||
expressionNodes.add(ExpressionBuilder.makeSelection(colIdx, 3)) | ||
expressionNodes.add(ExpressionBuilder.makeSelection(colIdx, 0)) | ||
case _ @VeloxIntermediateData.Type(veloxTypes: Seq[DataType]) => | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm confused here. Why can we match There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It will use |
||
val (sparkOrders, sparkTypes) = | ||
aggFunc.aggBufferAttributes.map(attr => (attr.name, attr.dataType)).unzip | ||
val veloxOrders = VeloxIntermediateData.veloxIntermediateDataOrder(aggFunc) | ||
val adjustedOrders = sparkOrders.map(veloxOrders.indexOf(_)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is it enough to decide the order based on name equality? E.g., if attr.name contains suffix of exprId, would it fail to match with the string in veloxOrders? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think the column names in |
||
sparkTypes.zipWithIndex.foreach { | ||
case (sparkType, idx) => | ||
val veloxType = veloxTypes(adjustedOrders(idx)) | ||
if (veloxType != sparkType) { | ||
// Velox and Spark have different type, adding a cast expression | ||
expressionNodes.add( | ||
ExpressionBuilder | ||
.makeCast( | ||
ConverterUtils.getTypeNode(sparkType, nullable = false), | ||
ExpressionBuilder.makeSelection(colIdx, adjustedOrders(idx)), | ||
SQLConf.get.ansiEnabled)) | ||
} else { | ||
// Velox and Spark have the same type | ||
expressionNodes.add(ExpressionBuilder.makeSelection(colIdx, adjustedOrders(idx))) | ||
} | ||
} | ||
colIdx += 1 | ||
case _ => | ||
expressionNodes.add(ExpressionBuilder.makeSelection(colIdx)) | ||
|
@@ -209,43 +172,6 @@ case class HashAggregateExecTransformer( | |
} | ||
} | ||
|
||
/** | ||
* Return the intermediate type node of a partial aggregation in Velox. | ||
* @param aggregateFunction | ||
* The aggregation function. | ||
* @return | ||
* The type of partial outputs. | ||
*/ | ||
private def getIntermediateTypeNode(aggregateFunction: AggregateFunction): TypeNode = { | ||
val structTypeNodes = aggregateFunction match { | ||
case avg: Average => | ||
ConverterUtils.getTypeNode(GlutenDecimalUtil.getAvgSumDataType(avg), nullable = true) :: | ||
ConverterUtils.getTypeNode(LongType, nullable = true) :: Nil | ||
case first: First => | ||
ConverterUtils.getTypeNode(first.dataType, nullable = true) :: | ||
ConverterUtils.getTypeNode(BooleanType, nullable = true) :: Nil | ||
case last: Last => | ||
ConverterUtils.getTypeNode(last.dataType, nullable = true) :: | ||
ConverterUtils.getTypeNode(BooleanType, nullable = true) :: Nil | ||
case maxMinBy: MaxMinBy => | ||
ConverterUtils.getTypeNode(maxMinBy.valueExpr.dataType, nullable = true) :: | ||
ConverterUtils.getTypeNode(maxMinBy.orderingExpr.dataType, nullable = true) :: Nil | ||
case _: StddevSamp | _: StddevPop | _: VarianceSamp | _: VariancePop => | ||
// Use struct type to represent Velox Row(BIGINT, DOUBLE, DOUBLE). | ||
veloxVarianceIntermediateTypes.map(ConverterUtils.getTypeNode(_, nullable = false)) | ||
case _: Corr => | ||
veloxCorrIntermediateTypes.map(ConverterUtils.getTypeNode(_, nullable = false)) | ||
case _: CovPopulation | _: CovSample => | ||
veloxCovarIntermediateTypes.map(ConverterUtils.getTypeNode(_, nullable = false)) | ||
case sum: Sum if sum.dataType.isInstanceOf[DecimalType] => | ||
ConverterUtils.getTypeNode(sum.dataType, nullable = true) :: | ||
ConverterUtils.getTypeNode(BooleanType, nullable = false) :: Nil | ||
case other => | ||
throw new UnsupportedOperationException(s"$other is not supported.") | ||
} | ||
TypeBuilder.makeStruct(false, structTypeNodes.asJava) | ||
} | ||
|
||
override protected def modeToKeyWord(aggregateMode: AggregateMode): String = { | ||
super.modeToKeyWord(if (mixedPartialAndMerge) Partial else aggregateMode) | ||
} | ||
|
@@ -268,15 +194,16 @@ case class HashAggregateExecTransformer( | |
VeloxAggregateFunctionsBuilder.create(args, aggregateFunction), | ||
childrenNodeList, | ||
modeKeyWord, | ||
getIntermediateTypeNode(aggregateFunction)) | ||
VeloxIntermediateData.getIntermediateTypeNode(aggregateFunction) | ||
) | ||
aggregateNodeList.add(partialNode) | ||
case PartialMerge => | ||
val aggFunctionNode = ExpressionBuilder.makeAggregateFunction( | ||
VeloxAggregateFunctionsBuilder | ||
.create(args, aggregateFunction, mixedPartialAndMerge), | ||
childrenNodeList, | ||
modeKeyWord, | ||
getIntermediateTypeNode(aggregateFunction) | ||
VeloxIntermediateData.getIntermediateTypeNode(aggregateFunction) | ||
) | ||
aggregateNodeList.add(aggFunctionNode) | ||
case Final => | ||
|
@@ -356,7 +283,7 @@ case class HashAggregateExecTransformer( | |
_: VariancePop | _: Corr | _: CovPopulation | _: CovSample | _: MaxMinBy => | ||
expression.mode match { | ||
case Partial | PartialMerge => | ||
typeNodeList.add(getIntermediateTypeNode(aggregateFunction)) | ||
typeNodeList.add(VeloxIntermediateData.getIntermediateTypeNode(aggregateFunction)) | ||
case Final => | ||
typeNodeList.add( | ||
ConverterUtils | ||
|
@@ -367,7 +294,7 @@ case class HashAggregateExecTransformer( | |
case sum: Sum if sum.dataType.isInstanceOf[DecimalType] => | ||
expression.mode match { | ||
case Partial | PartialMerge => | ||
typeNodeList.add(getIntermediateTypeNode(aggregateFunction)) | ||
typeNodeList.add(VeloxIntermediateData.getIntermediateTypeNode(aggregateFunction)) | ||
case Final => | ||
typeNodeList.add( | ||
ConverterUtils | ||
|
@@ -547,7 +474,7 @@ case class HashAggregateExecTransformer( | |
// Spark's Corr order is [n, xAvg, yAvg, ck, xMk, yMk] | ||
val sparkCorrOutputAttr = aggregateFunction.inputAggBufferAttributes.map(_.name) | ||
val veloxInputOrder = | ||
VeloxAggregateFunctionsBuilder.veloxCorrIntermediateDataOrder.map( | ||
VeloxIntermediateData.veloxCorrIntermediateDataOrder.map( | ||
name => sparkCorrOutputAttr.indexOf(name)) | ||
for (order <- veloxInputOrder) { | ||
val attr = functionInputAttributes(order) | ||
|
@@ -590,7 +517,7 @@ case class HashAggregateExecTransformer( | |
// Spark's Covar order is [n, xAvg, yAvg, ck] | ||
val sparkCorrOutputAttr = aggregateFunction.inputAggBufferAttributes.map(_.name) | ||
val veloxInputOrder = | ||
VeloxAggregateFunctionsBuilder.veloxCovarIntermediateDataOrder.map( | ||
VeloxIntermediateData.veloxCovarIntermediateDataOrder.map( | ||
name => sparkCorrOutputAttr.indexOf(name)) | ||
for (order <- veloxInputOrder) { | ||
val attr = functionInputAttributes(order) | ||
|
@@ -810,47 +737,6 @@ case class HashAggregateExecTransformer( | |
/** An aggregation function builder specifically used by Velox backend. */ | ||
object VeloxAggregateFunctionsBuilder { | ||
|
||
val veloxCorrIntermediateDataOrder: Seq[String] = Seq("ck", "n", "xMk", "yMk", "xAvg", "yAvg") | ||
val veloxCovarIntermediateDataOrder: Seq[String] = Seq("ck", "n", "xAvg", "yAvg") | ||
|
||
val veloxVarianceIntermediateTypes: Seq[DataType] = Seq(LongType, DoubleType, DoubleType) | ||
val veloxCovarIntermediateTypes: Seq[DataType] = Seq(DoubleType, LongType, DoubleType, DoubleType) | ||
val veloxCorrIntermediateTypes: Seq[DataType] = | ||
Seq(DoubleType, LongType, DoubleType, DoubleType, DoubleType, DoubleType) | ||
|
||
/** | ||
* Get the compatible input types for a Velox aggregate function. | ||
* @param aggregateFunc: | ||
* the input aggregate function. | |
* @param forMergeCompanion: | ||
* whether this is a special case to solve mixed aggregation phases. | ||
* @return | ||
* the input types of a Velox aggregate function. | ||
*/ | ||
private def getInputTypes( | ||
aggregateFunc: AggregateFunction, | ||
forMergeCompanion: Boolean): Seq[DataType] = { | ||
if (!forMergeCompanion) { | ||
return aggregateFunc.children.map(_.dataType) | ||
} | ||
aggregateFunc match { | ||
case _: StddevSamp | _: StddevPop | _: VarianceSamp | _: VariancePop => | ||
Seq(StructType(veloxVarianceIntermediateTypes.map(StructField("", _)).toArray)) | ||
case _: CovPopulation | _: CovSample => | ||
Seq(StructType(veloxCovarIntermediateTypes.map(StructField("", _)).toArray)) | ||
case _: Corr => | ||
Seq(StructType(veloxCorrIntermediateTypes.map(StructField("", _)).toArray)) | ||
case aggFunc if aggFunc.aggBufferAttributes.size > 1 => | ||
Seq( | ||
StructType( | ||
aggregateFunc.aggBufferAttributes | ||
.map(attribute => StructField("", attribute.dataType)) | ||
.toArray)) | ||
case _ => | ||
aggregateFunc.aggBufferAttributes.map(_.dataType) | ||
} | ||
} | ||
|
||
/** | ||
* Create a scalar function for the input aggregate function. | ||
* @param args: | ||
|
@@ -887,7 +773,7 @@ object VeloxAggregateFunctionsBuilder { | |
functionMap, | ||
ConverterUtils.makeFuncName( | ||
substraitAggFuncName, | ||
getInputTypes(aggregateFunc, forMergeCompanion), | ||
VeloxIntermediateData.getInputTypes(aggregateFunc, forMergeCompanion), | ||
FunctionConfig.REQ)) | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use aggFunc defined in the previous line.