Skip to content

Commit

Permalink
Introduce VeloxIntermediateData to adjust type and order
Browse files Browse the repository at this point in the history
  • Loading branch information
liujiayi771 committed Nov 16, 2023
1 parent 31e354f commit d2bf5fa
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 146 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,14 @@
*/
package io.glutenproject.execution

import io.glutenproject.execution.VeloxAggregateFunctionsBuilder._
import io.glutenproject.expression._
import io.glutenproject.expression.ConverterUtils.FunctionConfig
import io.glutenproject.substrait.`type`.{TypeBuilder, TypeNode}
import io.glutenproject.substrait.{AggregationParams, SubstraitContext}
import io.glutenproject.substrait.expression.{AggregateFunctionNode, ExpressionBuilder, ExpressionNode, ScalarFunctionNode}
import io.glutenproject.substrait.extensions.ExtensionBuilder
import io.glutenproject.substrait.rel.{RelBuilder, RelNode}
import io.glutenproject.utils.GlutenDecimalUtil
import io.glutenproject.utils.VeloxIntermediateData

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
Expand Down Expand Up @@ -80,27 +79,18 @@ case class HashAggregateExecTransformer(
* @return
* extracting needed or not.
*/
def extractStructNeeded(): Boolean = {
for (expr <- aggregateExpressions) {
val aggregateFunction = expr.aggregateFunction
aggregateFunction match {
case _: Average | _: First | _: Last | _: StddevSamp | _: StddevPop | _: VarianceSamp |
_: VariancePop | _: Corr | _: CovPopulation | _: CovSample | _: MaxMinBy =>
expr.mode match {
case Partial | PartialMerge =>
return true
case _ =>
}
case sum: Sum if sum.dataType.isInstanceOf[DecimalType] =>
expr.mode match {
case Partial | PartialMerge =>
return true
case _ =>
}
case _ =>
}
private def extractStructNeeded(): Boolean = {
aggregateExpressions.exists {
expr =>
expr.aggregateFunction match {
case aggFunc if aggFunc.aggBufferAttributes.size > 1 =>
expr.mode match {
case Partial | PartialMerge => true
case _ => false
}
case _ => false
}
}
false
}

/**
Expand Down Expand Up @@ -209,42 +199,7 @@ case class HashAggregateExecTransformer(
}
}

/**
* Return the intermediate type node of a partial aggregation in Velox.
* @param aggregateFunction
* The aggregation function.
* @return
* The type of partial outputs.
*/
private def getIntermediateTypeNode(aggregateFunction: AggregateFunction): TypeNode = {
val structTypeNodes = aggregateFunction match {
case avg: Average =>
ConverterUtils.getTypeNode(GlutenDecimalUtil.getAvgSumDataType(avg), nullable = true) ::
ConverterUtils.getTypeNode(LongType, nullable = true) :: Nil
case first: First =>
ConverterUtils.getTypeNode(first.dataType, nullable = true) ::
ConverterUtils.getTypeNode(BooleanType, nullable = true) :: Nil
case last: Last =>
ConverterUtils.getTypeNode(last.dataType, nullable = true) ::
ConverterUtils.getTypeNode(BooleanType, nullable = true) :: Nil
case maxMinBy: MaxMinBy =>
ConverterUtils.getTypeNode(maxMinBy.valueExpr.dataType, nullable = true) ::
ConverterUtils.getTypeNode(maxMinBy.orderingExpr.dataType, nullable = true) :: Nil
case _: StddevSamp | _: StddevPop | _: VarianceSamp | _: VariancePop =>
// Use struct type to represent Velox Row(BIGINT, DOUBLE, DOUBLE).
veloxVarianceIntermediateTypes.map(ConverterUtils.getTypeNode(_, nullable = false))
case _: Corr =>
veloxCorrIntermediateTypes.map(ConverterUtils.getTypeNode(_, nullable = false))
case _: CovPopulation | _: CovSample =>
veloxCovarIntermediateTypes.map(ConverterUtils.getTypeNode(_, nullable = false))
case sum: Sum if sum.dataType.isInstanceOf[DecimalType] =>
ConverterUtils.getTypeNode(sum.dataType, nullable = true) ::
ConverterUtils.getTypeNode(BooleanType, nullable = false) :: Nil
case other =>
throw new UnsupportedOperationException(s"$other is not supported.")
}
TypeBuilder.makeStruct(false, structTypeNodes.asJava)
}


override protected def modeToKeyWord(aggregateMode: AggregateMode): String = {
super.modeToKeyWord(if (mixedPartialAndMerge) Partial else aggregateMode)
Expand All @@ -268,15 +223,15 @@ case class HashAggregateExecTransformer(
VeloxAggregateFunctionsBuilder.create(args, aggregateFunction),
childrenNodeList,
modeKeyWord,
getIntermediateTypeNode(aggregateFunction))
VeloxIntermediateData.getIntermediateTypeNode(aggregateFunction))
aggregateNodeList.add(partialNode)
case PartialMerge =>
val aggFunctionNode = ExpressionBuilder.makeAggregateFunction(
VeloxAggregateFunctionsBuilder
.create(args, aggregateFunction, mixedPartialAndMerge),
childrenNodeList,
modeKeyWord,
getIntermediateTypeNode(aggregateFunction)
VeloxIntermediateData.getIntermediateTypeNode(aggregateFunction)
)
aggregateNodeList.add(aggFunctionNode)
case Final =>
Expand Down Expand Up @@ -356,7 +311,7 @@ case class HashAggregateExecTransformer(
_: VariancePop | _: Corr | _: CovPopulation | _: CovSample | _: MaxMinBy =>
expression.mode match {
case Partial | PartialMerge =>
typeNodeList.add(getIntermediateTypeNode(aggregateFunction))
typeNodeList.add(VeloxIntermediateData.getIntermediateTypeNode(aggregateFunction))
case Final =>
typeNodeList.add(
ConverterUtils
Expand All @@ -367,7 +322,7 @@ case class HashAggregateExecTransformer(
case sum: Sum if sum.dataType.isInstanceOf[DecimalType] =>
expression.mode match {
case Partial | PartialMerge =>
typeNodeList.add(getIntermediateTypeNode(aggregateFunction))
typeNodeList.add(VeloxIntermediateData.getIntermediateTypeNode(aggregateFunction))
case Final =>
typeNodeList.add(
ConverterUtils
Expand Down Expand Up @@ -547,7 +502,7 @@ case class HashAggregateExecTransformer(
// Spark's Corr order is [n, xAvg, yAvg, ck, xMk, yMk]
val sparkCorrOutputAttr = aggregateFunction.inputAggBufferAttributes.map(_.name)
val veloxInputOrder =
VeloxAggregateFunctionsBuilder.veloxCorrIntermediateDataOrder.map(
VeloxIntermediateData.veloxCorrIntermediateDataOrder.map(
name => sparkCorrOutputAttr.indexOf(name))
for (order <- veloxInputOrder) {
val attr = functionInputAttributes(order)
Expand Down Expand Up @@ -590,7 +545,7 @@ case class HashAggregateExecTransformer(
// Spark's Covar order is [n, xAvg, yAvg, ck]
val sparkCorrOutputAttr = aggregateFunction.inputAggBufferAttributes.map(_.name)
val veloxInputOrder =
VeloxAggregateFunctionsBuilder.veloxCovarIntermediateDataOrder.map(
VeloxIntermediateData.veloxCovarIntermediateDataOrder.map(
name => sparkCorrOutputAttr.indexOf(name))
for (order <- veloxInputOrder) {
val attr = functionInputAttributes(order)
Expand Down Expand Up @@ -810,47 +765,6 @@ case class HashAggregateExecTransformer(
/** An aggregation function builder specifically used by Velox backend. */
object VeloxAggregateFunctionsBuilder {

val veloxCorrIntermediateDataOrder: Seq[String] = Seq("ck", "n", "xMk", "yMk", "xAvg", "yAvg")
val veloxCovarIntermediateDataOrder: Seq[String] = Seq("ck", "n", "xAvg", "yAvg")

val veloxVarianceIntermediateTypes: Seq[DataType] = Seq(LongType, DoubleType, DoubleType)
val veloxCovarIntermediateTypes: Seq[DataType] = Seq(DoubleType, LongType, DoubleType, DoubleType)
val veloxCorrIntermediateTypes: Seq[DataType] =
Seq(DoubleType, LongType, DoubleType, DoubleType, DoubleType, DoubleType)

/**
* Get the compatible input types for a Velox aggregate function.
* @param aggregateFunc:
* the input aggregate function.
* @param forMergeCompanion:
* whether this is a special case to solve mixed aggregation phases.
* @return
* the input types of a Velox aggregate function.
*/
private def getInputTypes(
aggregateFunc: AggregateFunction,
forMergeCompanion: Boolean): Seq[DataType] = {
if (!forMergeCompanion) {
return aggregateFunc.children.map(_.dataType)
}
aggregateFunc match {
case _: StddevSamp | _: StddevPop | _: VarianceSamp | _: VariancePop =>
Seq(StructType(veloxVarianceIntermediateTypes.map(StructField("", _)).toArray))
case _: CovPopulation | _: CovSample =>
Seq(StructType(veloxCovarIntermediateTypes.map(StructField("", _)).toArray))
case _: Corr =>
Seq(StructType(veloxCorrIntermediateTypes.map(StructField("", _)).toArray))
case aggFunc if aggFunc.aggBufferAttributes.size > 1 =>
Seq(
StructType(
aggregateFunc.aggBufferAttributes
.map(attribute => StructField("", attribute.dataType))
.toArray))
case _ =>
aggregateFunc.aggBufferAttributes.map(_.dataType)
}
}

/**
* Create a scalar function for the input aggregate function.
* @param args:
Expand Down Expand Up @@ -887,7 +801,7 @@ object VeloxAggregateFunctionsBuilder {
functionMap,
ConverterUtils.makeFuncName(
substraitAggFuncName,
getInputTypes(aggregateFunc, forMergeCompanion),
VeloxIntermediateData.getInputTypes(aggregateFunc, forMergeCompanion),
FunctionConfig.REQ))
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.glutenproject.utils

import io.glutenproject.expression.ConverterUtils
import io.glutenproject.substrait.`type`.{TypeBuilder, TypeNode}

import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.types._

import scala.collection.JavaConverters._

/**
 * Describes the intermediate data of Velox aggregate functions: the order of the intermediate
 * columns and their types, with special cases for the functions whose intermediate layout is
 * inconsistent with Spark's.
 */
object VeloxIntermediateData {
  // Agg functions with inconsistent ordering of intermediate data between Velox and Spark.
  // Corr
  val veloxCorrIntermediateDataOrder: Seq[String] = Seq("ck", "n", "xMk", "yMk", "xAvg", "yAvg")
  // CovPopulation, CovSample
  val veloxCovarIntermediateDataOrder: Seq[String] = Seq("ck", "n", "xAvg", "yAvg")

  // Agg functions with inconsistent types of intermediate data between Velox and Spark.
  // StddevSamp, StddevPop, VarianceSamp, VariancePop
  val veloxVarianceIntermediateTypes: Seq[DataType] = Seq(LongType, DoubleType, DoubleType)
  // CovPopulation, CovSample
  val veloxCovarIntermediateTypes: Seq[DataType] = Seq(DoubleType, LongType, DoubleType, DoubleType)
  // Corr
  val veloxCorrIntermediateTypes: Seq[DataType] =
    Seq(DoubleType, LongType, DoubleType, DoubleType, DoubleType, DoubleType)

  /**
   * Return the intermediate column order of a Velox aggregation function. Functions whose
   * intermediate column order is inconsistent with Spark get a dedicated order; every other
   * function uses the names of its Spark aggregation buffer attributes, in buffer order.
   *
   * @param aggFunc
   *   Spark aggregation function.
   * @return
   *   The intermediate column names in the order Velox expects them.
   */
  def veloxIntermediateDataOrder(aggFunc: AggregateFunction): Seq[String] = {
    aggFunc match {
      case _: Corr =>
        veloxCorrIntermediateDataOrder
      // Match on the Covariance parent type, exactly as Type.unapply does, so that the column
      // order and the column types of a function are always selected by the same rule.
      case _: Covariance =>
        veloxCovarIntermediateDataOrder
      case _ =>
        aggFunc.aggBufferAttributes.map(_.name)
    }
  }

  /**
   * Get the compatible input types for a Velox aggregate function.
   *
   * @param aggregateFunc
   *   The input aggregate function.
   * @param forMergeCompanion
   *   Whether this is a special case to solve mixed aggregation phases. When false, the data types
   *   of the function's children are returned unchanged.
   * @return
   *   The input types of a Velox aggregate function. For a merge companion this is a single
   *   struct type wrapping the intermediate column types when the function has a multi-column
   *   intermediate result, otherwise the plain aggregation buffer types.
   */
  def getInputTypes(aggregateFunc: AggregateFunction, forMergeCompanion: Boolean): Seq[DataType] = {
    if (!forMergeCompanion) {
      aggregateFunc.children.map(_.dataType)
    } else {
      aggregateFunc match {
        case Type(veloxDataTypes) =>
          // Multi-column intermediate data is passed as one struct with unnamed fields.
          Seq(StructType(veloxDataTypes.map(StructField("", _)).toArray))
        case _ =>
          // Single-column intermediate data is not wrapped in a StructType.
          aggregateFunc.aggBufferAttributes.map(_.dataType)
      }
    }
  }

  /**
   * Return the intermediate type node of a partial aggregation in Velox.
   *
   * @param aggFunc
   *   Spark aggregation function.
   * @return
   *   The type of partial outputs: a struct node whose fields are the intermediate column types.
   *   Every field node is created with `nullable = false` and the struct itself is created by
   *   passing `false` as the first argument of `TypeBuilder.makeStruct`.
   * @throws UnsupportedOperationException
   *   If the function has no multi-column intermediate result (see [[Type]]).
   */
  def getIntermediateTypeNode(aggFunc: AggregateFunction): TypeNode = {
    val structTypeNodes =
      aggFunc match {
        case Type(dataTypes) =>
          dataTypes.map(ConverterUtils.getTypeNode(_, nullable = false))
        case _ =>
          // Name the offending function so that the failure can be diagnosed from the message.
          throw new UnsupportedOperationException(
            s"Can not get velox intermediate types. Unsupported aggregate function: $aggFunc")
      }
    TypeBuilder.makeStruct(false, structTypeNodes.asJava)
  }

  /** Extractor yielding the multi-column intermediate types of an aggregation function. */
  private object Type {

    /**
     * Return the intermediate types of Velox agg functions, with special matching required for
     * some aggregation functions where the intermediate results are inconsistent with Spark. Only
     * return a value if the intermediate result has multiple columns.
     *
     * @param aggFunc
     *   Spark aggregation function.
     * @return
     *   The intermediate types of the Velox aggregation function, or None when the function's
     *   aggregation buffer has at most one attribute and no special case applies.
     */
    def unapply(aggFunc: AggregateFunction): Option[Seq[DataType]] = {
      aggFunc match {
        case _: Corr =>
          Some(veloxCorrIntermediateTypes)
        case _: Covariance =>
          Some(veloxCovarIntermediateTypes)
        case _: StddevSamp | _: StddevPop | _: VarianceSamp | _: VariancePop =>
          Some(veloxVarianceIntermediateTypes)
        case _ if aggFunc.aggBufferAttributes.size > 1 =>
          // No Velox-specific layout: reuse the types of Spark's aggregation buffer.
          Some(aggFunc.aggBufferAttributes.map(_.dataType))
        case _ => None
      }
    }
  }
}

This file was deleted.

0 comments on commit d2bf5fa

Please sign in to comment.