From d2bf5fa68407121317a142521727319aa1883002 Mon Sep 17 00:00:00 2001 From: "joey.ljy" Date: Wed, 15 Nov 2023 15:09:32 +0800 Subject: [PATCH] Introduce VeloxIntermediateData to adjust type and order --- .../HashAggregateExecTransformer.scala | 126 +++-------------- .../utils/VeloxIntermediateData.scala | 131 ++++++++++++++++++ .../utils/GlutenDecimalUtil.scala | 40 ------ 3 files changed, 151 insertions(+), 146 deletions(-) create mode 100644 backends-velox/src/main/scala/io/glutenproject/utils/VeloxIntermediateData.scala delete mode 100644 gluten-core/src/main/scala/io/glutenproject/utils/GlutenDecimalUtil.scala diff --git a/backends-velox/src/main/scala/io/glutenproject/execution/HashAggregateExecTransformer.scala b/backends-velox/src/main/scala/io/glutenproject/execution/HashAggregateExecTransformer.scala index 1b8e10893a891..731b5f75a3db5 100644 --- a/backends-velox/src/main/scala/io/glutenproject/execution/HashAggregateExecTransformer.scala +++ b/backends-velox/src/main/scala/io/glutenproject/execution/HashAggregateExecTransformer.scala @@ -16,7 +16,6 @@ */ package io.glutenproject.execution -import io.glutenproject.execution.VeloxAggregateFunctionsBuilder._ import io.glutenproject.expression._ import io.glutenproject.expression.ConverterUtils.FunctionConfig import io.glutenproject.substrait.`type`.{TypeBuilder, TypeNode} @@ -24,7 +23,7 @@ import io.glutenproject.substrait.{AggregationParams, SubstraitContext} import io.glutenproject.substrait.expression.{AggregateFunctionNode, ExpressionBuilder, ExpressionNode, ScalarFunctionNode} import io.glutenproject.substrait.extensions.ExtensionBuilder import io.glutenproject.substrait.rel.{RelBuilder, RelNode} -import io.glutenproject.utils.GlutenDecimalUtil +import io.glutenproject.utils.VeloxIntermediateData import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ @@ -80,27 +79,18 @@ case class HashAggregateExecTransformer( * @return * extracting needed or not. */ - def extractStructNeeded(): Boolean = { - for (expr <- aggregateExpressions) { - val aggregateFunction = expr.aggregateFunction - aggregateFunction match { - case _: Average | _: First | _: Last | _: StddevSamp | _: StddevPop | _: VarianceSamp | - _: VariancePop | _: Corr | _: CovPopulation | _: CovSample | _: MaxMinBy => - expr.mode match { - case Partial | PartialMerge => - return true - case _ => - } - case sum: Sum if sum.dataType.isInstanceOf[DecimalType] => - expr.mode match { - case Partial | PartialMerge => - return true - case _ => - } - case _ => - } + private def extractStructNeeded(): Boolean = { + aggregateExpressions.exists { + expr => + expr.aggregateFunction match { + case aggFunc if aggFunc.aggBufferAttributes.size > 1 => + expr.mode match { + case Partial | PartialMerge => true + case _ => false + } + case _ => false + } } - false } /** @@ -209,42 +199,7 @@ case class HashAggregateExecTransformer( } } - /** - * Return the intermediate type node of a partial aggregation in Velox. - * @param aggregateFunction - * The aggregation function. - * @return - * The type of partial outputs. - */ - private def getIntermediateTypeNode(aggregateFunction: AggregateFunction): TypeNode = { - val structTypeNodes = aggregateFunction match { - case avg: Average => - ConverterUtils.getTypeNode(GlutenDecimalUtil.getAvgSumDataType(avg), nullable = true) :: - ConverterUtils.getTypeNode(LongType, nullable = true) :: Nil - case first: First => - ConverterUtils.getTypeNode(first.dataType, nullable = true) :: - ConverterUtils.getTypeNode(BooleanType, nullable = true) :: Nil - case last: Last => - ConverterUtils.getTypeNode(last.dataType, nullable = true) :: - ConverterUtils.getTypeNode(BooleanType, nullable = true) :: Nil - case maxMinBy: MaxMinBy => - ConverterUtils.getTypeNode(maxMinBy.valueExpr.dataType, nullable = true) :: - ConverterUtils.getTypeNode(maxMinBy.orderingExpr.dataType, nullable = true) :: Nil - case _: StddevSamp | _: StddevPop | _: VarianceSamp | _: VariancePop => - // Use struct type to represent Velox Row(BIGINT, DOUBLE, DOUBLE). - veloxVarianceIntermediateTypes.map(ConverterUtils.getTypeNode(_, nullable = false)) - case _: Corr => - veloxCorrIntermediateTypes.map(ConverterUtils.getTypeNode(_, nullable = false)) - case _: CovPopulation | _: CovSample => - veloxCovarIntermediateTypes.map(ConverterUtils.getTypeNode(_, nullable = false)) - case sum: Sum if sum.dataType.isInstanceOf[DecimalType] => - ConverterUtils.getTypeNode(sum.dataType, nullable = true) :: - ConverterUtils.getTypeNode(BooleanType, nullable = false) :: Nil - case other => - throw new UnsupportedOperationException(s"$other is not supported.") - } - TypeBuilder.makeStruct(false, structTypeNodes.asJava) - } + override protected def modeToKeyWord(aggregateMode: AggregateMode): String = { super.modeToKeyWord(if (mixedPartialAndMerge) Partial else aggregateMode) @@ -268,7 +223,7 @@ case class HashAggregateExecTransformer( VeloxAggregateFunctionsBuilder.create(args, aggregateFunction), childrenNodeList, modeKeyWord, - getIntermediateTypeNode(aggregateFunction)) + VeloxIntermediateData.getIntermediateTypeNode(aggregateFunction)) aggregateNodeList.add(partialNode) case PartialMerge => val aggFunctionNode = ExpressionBuilder.makeAggregateFunction( @@ -276,7 +231,7 @@ case class HashAggregateExecTransformer( .create(args, aggregateFunction, mixedPartialAndMerge), childrenNodeList, modeKeyWord, - getIntermediateTypeNode(aggregateFunction) + VeloxIntermediateData.getIntermediateTypeNode(aggregateFunction) ) aggregateNodeList.add(aggFunctionNode) case Final => @@ -356,7 +311,7 @@ case class HashAggregateExecTransformer( _: VariancePop | _: Corr | _: CovPopulation | _: CovSample | _: MaxMinBy => expression.mode match { case Partial | PartialMerge => - typeNodeList.add(getIntermediateTypeNode(aggregateFunction)) + typeNodeList.add(VeloxIntermediateData.getIntermediateTypeNode(aggregateFunction)) case Final => typeNodeList.add( ConverterUtils @@ -367,7 +322,7 @@ case class HashAggregateExecTransformer( case sum: Sum if sum.dataType.isInstanceOf[DecimalType] => expression.mode match { case Partial | PartialMerge => - typeNodeList.add(getIntermediateTypeNode(aggregateFunction)) + typeNodeList.add(VeloxIntermediateData.getIntermediateTypeNode(aggregateFunction)) case Final => typeNodeList.add( ConverterUtils @@ -547,7 +502,7 @@ case class HashAggregateExecTransformer( // Spark's Corr order is [n, xAvg, yAvg, ck, xMk, yMk] val sparkCorrOutputAttr = aggregateFunction.inputAggBufferAttributes.map(_.name) val veloxInputOrder = - VeloxAggregateFunctionsBuilder.veloxCorrIntermediateDataOrder.map( + VeloxIntermediateData.veloxCorrIntermediateDataOrder.map( name => sparkCorrOutputAttr.indexOf(name)) for (order <- veloxInputOrder) { val attr = functionInputAttributes(order) @@ -590,7 +545,7 @@ case class HashAggregateExecTransformer( // Spark's Covar order is [n, xAvg, yAvg, ck] val sparkCorrOutputAttr = aggregateFunction.inputAggBufferAttributes.map(_.name) val veloxInputOrder = - VeloxAggregateFunctionsBuilder.veloxCovarIntermediateDataOrder.map( + VeloxIntermediateData.veloxCovarIntermediateDataOrder.map( name => sparkCorrOutputAttr.indexOf(name)) for (order <- veloxInputOrder) { val attr = functionInputAttributes(order) @@ -810,47 +765,6 @@ case class HashAggregateExecTransformer( /** An aggregation function builder specifically used by Velox backend. */ object VeloxAggregateFunctionsBuilder { - val veloxCorrIntermediateDataOrder: Seq[String] = Seq("ck", "n", "xMk", "yMk", "xAvg", "yAvg") - val veloxCovarIntermediateDataOrder: Seq[String] = Seq("ck", "n", "xAvg", "yAvg") - - val veloxVarianceIntermediateTypes: Seq[DataType] = Seq(LongType, DoubleType, DoubleType) - val veloxCovarIntermediateTypes: Seq[DataType] = Seq(DoubleType, LongType, DoubleType, DoubleType) - val veloxCorrIntermediateTypes: Seq[DataType] = - Seq(DoubleType, LongType, DoubleType, DoubleType, DoubleType, DoubleType) - - /** - * Get the compatible input types for a Velox aggregate function. - * @param aggregateFunc: - * the input aggreagate function. - * @param forMergeCompanion: - * whether this is a special case to solve mixed aggregation phases. - * @return - * the input types of a Velox aggregate function. - */ - private def getInputTypes( - aggregateFunc: AggregateFunction, - forMergeCompanion: Boolean): Seq[DataType] = { - if (!forMergeCompanion) { - return aggregateFunc.children.map(_.dataType) - } - aggregateFunc match { - case _: StddevSamp | _: StddevPop | _: VarianceSamp | _: VariancePop => - Seq(StructType(veloxVarianceIntermediateTypes.map(StructField("", _)).toArray)) - case _: CovPopulation | _: CovSample => - Seq(StructType(veloxCovarIntermediateTypes.map(StructField("", _)).toArray)) - case _: Corr => - Seq(StructType(veloxCorrIntermediateTypes.map(StructField("", _)).toArray)) - case aggFunc if aggFunc.aggBufferAttributes.size > 1 => - Seq( - StructType( - aggregateFunc.aggBufferAttributes - .map(attribute => StructField("", attribute.dataType)) - .toArray)) - case _ => - aggregateFunc.aggBufferAttributes.map(_.dataType) - } - } - /** * Create an scalar function for the input aggregate function. * @param args: @@ -887,7 +801,7 @@ object VeloxAggregateFunctionsBuilder { functionMap, ConverterUtils.makeFuncName( substraitAggFuncName, - getInputTypes(aggregateFunc, forMergeCompanion), + VeloxIntermediateData.getInputTypes(aggregateFunc, forMergeCompanion), FunctionConfig.REQ)) } } diff --git a/backends-velox/src/main/scala/io/glutenproject/utils/VeloxIntermediateData.scala b/backends-velox/src/main/scala/io/glutenproject/utils/VeloxIntermediateData.scala new file mode 100644 index 0000000000000..d9ca647603343 --- /dev/null +++ b/backends-velox/src/main/scala/io/glutenproject/utils/VeloxIntermediateData.scala @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.glutenproject.utils + +import io.glutenproject.expression.ConverterUtils +import io.glutenproject.substrait.`type`.{TypeBuilder, TypeNode} + +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.types._ + +import scala.collection.JavaConverters._ + +object VeloxIntermediateData { + // Agg functions with inconsistent ordering of intermediate data between Velox and Spark. + // Corr + val veloxCorrIntermediateDataOrder: Seq[String] = Seq("ck", "n", "xMk", "yMk", "xAvg", "yAvg") + // CovPopulation, CovSample + val veloxCovarIntermediateDataOrder: Seq[String] = Seq("ck", "n", "xAvg", "yAvg") + + // Agg functions with inconsistent types of intermediate data between Velox and Spark. + // StddevSamp, StddevPop, VarianceSamp, VariancePop + val veloxVarianceIntermediateTypes: Seq[DataType] = Seq(LongType, DoubleType, DoubleType) + // CovPopulation, CovSample + val veloxCovarIntermediateTypes: Seq[DataType] = Seq(DoubleType, LongType, DoubleType, DoubleType) + // Corr + val veloxCorrIntermediateTypes: Seq[DataType] = + Seq(DoubleType, LongType, DoubleType, DoubleType, DoubleType, DoubleType) + + /** + * Return the intermediate columns order of Velox aggregation functions, with special matching + * required for some aggregation functions where the intermediate columns order are inconsistent + * with Spark. + * @param aggFunc + * : Spark aggregation function + * @return + * the intermediate columns order of Velox aggregation functions + */ + def veloxIntermediateDataOrder(aggFunc: AggregateFunction): Seq[String] = { + aggFunc match { + case _: Corr => + veloxCorrIntermediateDataOrder + case _: CovPopulation | _: CovSample => + veloxCovarIntermediateDataOrder + case _ => + aggFunc.aggBufferAttributes.map(_.name) + } + } + + /** + * Get the compatible input types for a Velox aggregate function. + * + * @param aggregateFunc + * The input aggregate function. + * @param forMergeCompanion + * Whether this is a special case to solve mixed aggregation phases. + * @return + * The input types of a Velox aggregate function. + */ + def getInputTypes(aggregateFunc: AggregateFunction, forMergeCompanion: Boolean): Seq[DataType] = { + if (!forMergeCompanion) { + return aggregateFunc.children.map(_.dataType) + } + aggregateFunc match { + case _ @Type(veloxDataTypes: Seq[DataType]) => + Seq(StructType(veloxDataTypes.map(StructField("", _)).toArray)) + case _ => + // Not use StructType for single column agg intermediate data + aggregateFunc.aggBufferAttributes.map(_.dataType) + } + } + + /** + * Return the intermediate type node of a partial aggregation in Velox. + * + * @param aggFunc + * Spark aggregation function. + * @return + * The type of partial outputs. + */ + def getIntermediateTypeNode(aggFunc: AggregateFunction): TypeNode = { + val structTypeNodes = + aggFunc match { + case _ @Type(dataTypes: Seq[DataType]) => + dataTypes.map(ConverterUtils.getTypeNode(_, nullable = false)) + case _ => + throw new UnsupportedOperationException("Can not get velox intermediate types.") + } + TypeBuilder.makeStruct(false, structTypeNodes.asJava) + } + + private object Type { + + /** + * Return the intermediate types of Velox agg functions, with special matching required for some + * aggregation functions where the intermediate results are inconsistent with Spark. Only return + * if the intermediate result has multiple columns. + * @param aggFunc + * Spark aggregation function + * @return + * the intermediate types of Velox aggregation functions. + */ + def unapply(aggFunc: AggregateFunction): Option[Seq[DataType]] = { + aggFunc match { + case _: Corr => + Some(veloxCorrIntermediateTypes) + case _: Covariance => + Some(veloxCovarIntermediateTypes) + case _: StddevSamp | _: StddevPop | _: VarianceSamp | _: VariancePop => + Some(veloxVarianceIntermediateTypes) + case _ if aggFunc.aggBufferAttributes.size > 1 => + Some(aggFunc.aggBufferAttributes.map(_.dataType)) + case _ => None + } + } + } +} diff --git a/gluten-core/src/main/scala/io/glutenproject/utils/GlutenDecimalUtil.scala b/gluten-core/src/main/scala/io/glutenproject/utils/GlutenDecimalUtil.scala deleted file mode 100644 index e35473835d78e..0000000000000 --- a/gluten-core/src/main/scala/io/glutenproject/utils/GlutenDecimalUtil.scala +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.glutenproject.utils - -import org.apache.spark.sql.catalyst.expressions.aggregate.Average -import org.apache.spark.sql.types.{DataType, DecimalType, DoubleType} -import org.apache.spark.sql.types.DecimalType.{MAX_PRECISION, MAX_SCALE} - -import scala.math.min - -object GlutenDecimalUtil { - object Fixed { - def unapply(t: DecimalType): Option[(Int, Int)] = Some((t.precision, t.scale)) - } - - def bounded(precision: Int, scale: Int): DecimalType = { - DecimalType(min(precision, MAX_PRECISION), min(scale, MAX_SCALE)) - } - - def getAvgSumDataType(avg: Average): DataType = avg.dataType match { - // avg.dataType is Decimal(p + 4, s + 4) and sumType is Decimal(p + 10, s) - // we need to get sumType, so p = p - 4 + 10 and s = s - 4 - case _ @GlutenDecimalUtil.Fixed(p, s) => GlutenDecimalUtil.bounded(p - 4 + 10, s - 4) - case _ => DoubleType - } -}