From d8248ea8ece0f256d5eda92d22d72ac0819b1596 Mon Sep 17 00:00:00 2001 From: Hongze Zhang Date: Fri, 5 Jan 2024 21:57:02 +0800 Subject: [PATCH] [VL] Add FlushableHashAggregateExecTransformer to map Velox's partial aggregation which supports flushing and abandoning (#4130) --- .../velox/SparkPlanExecApiImpl.scala | 5 +- .../HashAggregateExecTransformer.scala | 141 +++++++++++------- .../catalyst/FlushableHashAggregateRule.scala | 56 +++++++ .../execution/TestOperator.scala | 17 ++- cpp/velox/compute/WholeStageResultIterator.cc | 4 - .../substrait/SubstraitExtensionCollector.cc | 9 -- .../substrait/SubstraitExtensionCollector.h | 7 - cpp/velox/substrait/SubstraitToVeloxPlan.cc | 55 +++++-- cpp/velox/substrait/SubstraitToVeloxPlan.h | 10 ++ .../SubstraitToVeloxPlanValidator.cc | 50 +------ .../substrait/VeloxSubstraitSignature.cc | 16 +- cpp/velox/substrait/VeloxToSubstraitPlan.cc | 63 +++++++- .../tests/SubstraitExtensionCollectorTest.cc | 6 +- .../tests/VeloxSubstraitRoundTripTest.cc | 94 +++++++++++- .../tests/VeloxSubstraitSignatureTest.cc | 8 +- docs/get-started/Velox.md | 1 - ep/build-velox/src/get_velox.sh | 2 +- .../expression/AggregateFunctionNode.java | 2 +- .../HashAggregateExecBaseTransformer.scala | 13 +- .../WholeStageTransformerSuite.scala | 2 +- 20 files changed, 397 insertions(+), 164 deletions(-) create mode 100644 backends-velox/src/main/scala/org/apache/spark/sql/catalyst/FlushableHashAggregateRule.scala diff --git a/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/SparkPlanExecApiImpl.scala b/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/SparkPlanExecApiImpl.scala index a9743435f373..a50401f48c2b 100644 --- a/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/SparkPlanExecApiImpl.scala +++ b/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/SparkPlanExecApiImpl.scala @@ -157,7 +157,7 @@ class SparkPlanExecApiImpl extends SparkPlanExecApi { initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], child: SparkPlan): HashAggregateExecBaseTransformer = - HashAggregateExecTransformer( + RegularHashAggregateExecTransformer( requiredChildDistributionExpressions, groupingExpressions, aggregateExpressions, @@ -488,7 +488,8 @@ class SparkPlanExecApiImpl extends SparkPlanExecApi { * * @return */ - override def genExtendedColumnarPreRules(): List[SparkSession => Rule[SparkPlan]] = List() + override def genExtendedColumnarPreRules(): List[SparkSession => Rule[SparkPlan]] = + List() /** * Generate extended columnar post-rules. diff --git a/backends-velox/src/main/scala/io/glutenproject/execution/HashAggregateExecTransformer.scala b/backends-velox/src/main/scala/io/glutenproject/execution/HashAggregateExecTransformer.scala index 92606ab56398..c476c5310786 100644 --- a/backends-velox/src/main/scala/io/glutenproject/execution/HashAggregateExecTransformer.scala +++ b/backends-velox/src/main/scala/io/glutenproject/execution/HashAggregateExecTransformer.scala @@ -38,7 +38,7 @@ import java.util.{ArrayList => JArrayList, HashMap => JHashMap, List => JList} import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer -case class HashAggregateExecTransformer( +abstract class HashAggregateExecTransformer( requiredChildDistributionExpressions: Option[Seq[Expression]], groupingExpressions: Seq[NamedExpression], aggregateExpressions: Seq[AggregateExpression], @@ -69,10 +69,6 @@ case class HashAggregateExecTransformer( } } - override protected def withNewChildInternal(newChild: SparkPlan): HashAggregateExecTransformer = { - copy(child = newChild) - } - /** * Returns whether extracting subfield from struct is needed. True when the intermediate type of * Velox aggregation is a compound type. @@ -173,15 +169,16 @@ case class HashAggregateExecTransformer( } } - override protected def modeToKeyWord(aggregateMode: AggregateMode): String = { - super.modeToKeyWord(if (mixedPartialAndMerge) { - Partial - } else { - aggregateMode match { - case PartialMerge => Final - case _ => aggregateMode - } - }) + // Whether the output data allows to be just pre-aggregated rather than + // fully aggregated. If true, aggregation could flush its in memory + // aggregated data whenever is needed rather than waiting for all input + // to be read. + protected def allowFlush: Boolean + + override protected def formatExtOptimizationString(isStreaming: Boolean): String = { + val isStreamingStr = if (isStreaming) "1" else "0" + val allowFlushStr = if (allowFlush) "1" else "0" + s"isStreaming=$isStreamingStr\nallowFlush=$allowFlushStr\n" } // Create aggregate function node and add to list. @@ -191,15 +188,13 @@ case class HashAggregateExecTransformer( childrenNodeList: JList[ExpressionNode], aggregateMode: AggregateMode, aggregateNodeList: JList[AggregateFunctionNode]): Unit = { - // This is a special handling for PartialMerge in the execution of distinct. - // Use Partial phase instead for this aggregation. val modeKeyWord = modeToKeyWord(aggregateMode) def generateMergeCompanionNode(): Unit = { aggregateMode match { case Partial => val partialNode = ExpressionBuilder.makeAggregateFunction( - VeloxAggregateFunctionsBuilder.create(args, aggregateFunction), + VeloxAggregateFunctionsBuilder.create(args, aggregateFunction, aggregateMode), childrenNodeList, modeKeyWord, VeloxIntermediateData.getIntermediateTypeNode(aggregateFunction) @@ -208,7 +203,7 @@ case class HashAggregateExecTransformer( case PartialMerge => val aggFunctionNode = ExpressionBuilder.makeAggregateFunction( VeloxAggregateFunctionsBuilder - .create(args, aggregateFunction, mixedPartialAndMerge, purePartialMerge), + .create(args, aggregateFunction, aggregateMode), childrenNodeList, modeKeyWord, VeloxIntermediateData.getIntermediateTypeNode(aggregateFunction) @@ -216,7 +211,7 @@ case class HashAggregateExecTransformer( aggregateNodeList.add(aggFunctionNode) case Final => val aggFunctionNode = ExpressionBuilder.makeAggregateFunction( - VeloxAggregateFunctionsBuilder.create(args, aggregateFunction), + VeloxAggregateFunctionsBuilder.create(args, aggregateFunction, aggregateMode), childrenNodeList, modeKeyWord, ConverterUtils.getTypeNode(aggregateFunction.dataType, aggregateFunction.nullable) @@ -233,7 +228,7 @@ case class HashAggregateExecTransformer( case Partial => // For Partial mode output type is binary. val partialNode = ExpressionBuilder.makeAggregateFunction( - VeloxAggregateFunctionsBuilder.create(args, aggregateFunction), + VeloxAggregateFunctionsBuilder.create(args, aggregateFunction, aggregateMode), childrenNodeList, modeKeyWord, ConverterUtils.getTypeNode( @@ -244,7 +239,7 @@ case class HashAggregateExecTransformer( case Final => // For Final mode output type is long. val aggFunctionNode = ExpressionBuilder.makeAggregateFunction( - VeloxAggregateFunctionsBuilder.create(args, aggregateFunction), + VeloxAggregateFunctionsBuilder.create(args, aggregateFunction, aggregateMode), childrenNodeList, modeKeyWord, ConverterUtils.getTypeNode(aggregateFunction.dataType, aggregateFunction.nullable) @@ -257,11 +252,7 @@ case class HashAggregateExecTransformer( generateMergeCompanionNode() case _ => val aggFunctionNode = ExpressionBuilder.makeAggregateFunction( - VeloxAggregateFunctionsBuilder.create( - args, - aggregateFunction, - aggregateMode == PartialMerge && mixedPartialAndMerge, - aggregateMode == PartialMerge && purePartialMerge), + VeloxAggregateFunctionsBuilder.create(args, aggregateFunction, aggregateMode), childrenNodeList, modeKeyWord, ConverterUtils.getTypeNode(aggregateFunction.dataType, aggregateFunction.nullable) @@ -364,7 +355,8 @@ case class HashAggregateExecTransformer( val aggFunc = aggregateExpression.aggregateFunction val functionInputAttributes = aggFunc.inputAggBufferAttributes aggFunc match { - case _ if mixedPartialAndMerge && aggregateExpression.mode == Partial => + case _ + if aggregateExpression.mode == Partial => // FIXME: Any difference with the last branch? val childNodes = aggFunc.children .map( ExpressionConverter @@ -498,21 +490,6 @@ case class HashAggregateExecTransformer( operatorId) } - /** - * Whether this is a mixed aggregation of partial and partial-merge aggregation functions. - * @return - * whether partial and partial-merge functions coexist. - */ - def mixedPartialAndMerge: Boolean = { - val partialMergeExists = aggregateExpressions.exists(_.mode == PartialMerge) - val partialExists = aggregateExpressions.exists(_.mode == Partial) - partialMergeExists && partialExists - } - - def purePartialMerge: Boolean = { - aggregateExpressions.forall(_.mode == PartialMerge) - } - /** * Create and return the Rel for the this aggregation. * @param context @@ -589,8 +566,7 @@ object VeloxAggregateFunctionsBuilder { def create( args: java.lang.Object, aggregateFunc: AggregateFunction, - forMergeCompanion: Boolean = false, - purePartialMerge: Boolean = false): Long = { + mode: AggregateMode): Long = { val functionMap = args.asInstanceOf[JHashMap[String, JLong]] var sigName = ExpressionMappings.expressionsMap.get(aggregateFunc.getClass) @@ -606,22 +582,71 @@ object VeloxAggregateFunctionsBuilder { case _ => } - // Use companion function for partial-merge aggregation functions on count distinct. - val substraitAggFuncName = { - if (purePartialMerge) { - sigName.get + "_partial" - } else if (forMergeCompanion) { - sigName.get + "_merge" - } else { - sigName.get - } - } - ExpressionBuilder.newScalarFunction( functionMap, ConverterUtils.makeFuncName( - substraitAggFuncName, - VeloxIntermediateData.getInputTypes(aggregateFunc, forMergeCompanion), - FunctionConfig.REQ)) + // Substrait-to-Velox procedure will choose appropriate companion function if needed. + sigName.get, + VeloxIntermediateData.getInputTypes(aggregateFunc, mode == PartialMerge || mode == Final), + FunctionConfig.REQ + ) + ) + } +} + +// Hash aggregation that emits full-aggregated data, this works like regular hash +// aggregation in Vanilla Spark. +case class RegularHashAggregateExecTransformer( + requiredChildDistributionExpressions: Option[Seq[Expression]], + groupingExpressions: Seq[NamedExpression], + aggregateExpressions: Seq[AggregateExpression], + aggregateAttributes: Seq[Attribute], + initialInputBufferOffset: Int, + resultExpressions: Seq[NamedExpression], + child: SparkPlan) + extends HashAggregateExecTransformer( + requiredChildDistributionExpressions, + groupingExpressions, + aggregateExpressions, + aggregateAttributes, + initialInputBufferOffset, + resultExpressions, + child) { + + override protected def allowFlush: Boolean = false + + override def simpleString(maxFields: Int): String = s"${super.simpleString(maxFields)}" + + override protected def withNewChildInternal(newChild: SparkPlan): HashAggregateExecTransformer = { + copy(child = newChild) + } +} + +// Hash aggregation that emits pre-aggregated data which allows duplications on grouping keys +// among its output rows. +case class FlushableHashAggregateExecTransformer( + requiredChildDistributionExpressions: Option[Seq[Expression]], + groupingExpressions: Seq[NamedExpression], + aggregateExpressions: Seq[AggregateExpression], + aggregateAttributes: Seq[Attribute], + initialInputBufferOffset: Int, + resultExpressions: Seq[NamedExpression], + child: SparkPlan) + extends HashAggregateExecTransformer( + requiredChildDistributionExpressions, + groupingExpressions, + aggregateExpressions, + aggregateAttributes, + initialInputBufferOffset, + resultExpressions, + child) { + + override protected def allowFlush: Boolean = true + + override def simpleString(maxFields: Int): String = + s"Intermediate${super.simpleString(maxFields)}" + + override protected def withNewChildInternal(newChild: SparkPlan): HashAggregateExecTransformer = { + copy(child = newChild) } } diff --git a/backends-velox/src/main/scala/org/apache/spark/sql/catalyst/FlushableHashAggregateRule.scala b/backends-velox/src/main/scala/org/apache/spark/sql/catalyst/FlushableHashAggregateRule.scala new file mode 100644 index 000000000000..786ac9145af5 --- /dev/null +++ b/backends-velox/src/main/scala/org/apache/spark/sql/catalyst/FlushableHashAggregateRule.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst + +import io.glutenproject.execution.{FlushableHashAggregateExecTransformer, RegularHashAggregateExecTransformer} + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions.aggregate.{Partial, PartialMerge} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike + +/** + * To transform regular aggregation to intermediate aggregation that internally enables + * optimizations such as flushing and abandoning. + * + * Currently not in use. Will be enabled via a configuration after necessary verification is done. + */ +case class FlushableHashAggregateRule(session: SparkSession) extends Rule[SparkPlan] { + override def apply(plan: SparkPlan): SparkPlan = plan.transformUp { + case shuffle: ShuffleExchangeLike => + // If an exchange follows a hash aggregate in which all functions are in partial mode, + // then it's safe to convert the hash aggregate to intermediate hash aggregate. + shuffle.child match { + case h: RegularHashAggregateExecTransformer => + if (h.aggregateExpressions.forall(p => p.mode == Partial || p.mode == PartialMerge)) { + shuffle.withNewChildren( + Seq(FlushableHashAggregateExecTransformer( + h.requiredChildDistributionExpressions, + h.groupingExpressions, + h.aggregateExpressions, + h.aggregateAttributes, + h.initialInputBufferOffset, + h.resultExpressions, + h.child + ))) + } else { + shuffle + } + } + } +} diff --git a/backends-velox/src/test/scala/io/glutenproject/execution/TestOperator.scala b/backends-velox/src/test/scala/io/glutenproject/execution/TestOperator.scala index 9d5c4afed0d7..b1eb59605d1b 100644 --- a/backends-velox/src/test/scala/io/glutenproject/execution/TestOperator.scala +++ b/backends-velox/src/test/scala/io/glutenproject/execution/TestOperator.scala @@ -631,14 +631,27 @@ class TestOperator extends VeloxWholeStageTransformerSuite with AdaptiveSparkPla } } - test("Support get native plan tree string") { + test("Support get native plan tree string, Velox single aggregation") { runQueryAndCompare("select l_partkey + 1, count(*) from lineitem group by l_partkey + 1") { df => val wholeStageTransformers = collect(df.queryExecution.executedPlan) { case w: WholeStageTransformer => w } val nativePlanString = wholeStageTransformers.head.nativePlanString() - assert(nativePlanString.contains("Aggregation[FINAL")) + assert(nativePlanString.contains("Aggregation[SINGLE")) + assert(nativePlanString.contains("TableScan")) + } + } + + // After IntermediateHashAggregateRule is enabled + ignore("Support get native plan tree string") { + runQueryAndCompare("select l_partkey + 1, count(*) from lineitem group by l_partkey + 1") { + df => + val wholeStageTransformers = collect(df.queryExecution.executedPlan) { + case w: WholeStageTransformer => w + } + val nativePlanString = wholeStageTransformers.head.nativePlanString() + assert(nativePlanString.contains("Aggregation[SINGLE")) assert(nativePlanString.contains("Aggregation[PARTIAL")) assert(nativePlanString.contains("TableScan")) } diff --git a/cpp/velox/compute/WholeStageResultIterator.cc b/cpp/velox/compute/WholeStageResultIterator.cc index e483f4a1162a..f079904eadec 100644 --- a/cpp/velox/compute/WholeStageResultIterator.cc +++ b/cpp/velox/compute/WholeStageResultIterator.cc @@ -39,8 +39,6 @@ const std::string kHiveConnectorId = "test-hive"; // memory const std::string kSpillStrategy = "spark.gluten.sql.columnar.backend.velox.spillStrategy"; const std::string kSpillStrategyDefaultValue = "auto"; -const std::string kPartialAggregationSpillEnabled = - "spark.gluten.sql.columnar.backend.velox.partialAggregationSpillEnabled"; const std::string kAggregationSpillEnabled = "spark.gluten.sql.columnar.backend.velox.aggregationSpillEnabled"; const std::string kJoinSpillEnabled = "spark.gluten.sql.columnar.backend.velox.joinSpillEnabled"; const std::string kOrderBySpillEnabled = "spark.gluten.sql.columnar.backend.velox.orderBySpillEnabled"; @@ -378,8 +376,6 @@ std::unordered_map WholeStageResultIterator::getQueryC } configs[velox::core::QueryConfig::kAggregationSpillEnabled] = std::to_string(veloxCfg_->get(kAggregationSpillEnabled, true)); - configs[velox::core::QueryConfig::kPartialAggregationSpillEnabled] = - std::to_string(veloxCfg_->get(kPartialAggregationSpillEnabled, true)); configs[velox::core::QueryConfig::kJoinSpillEnabled] = std::to_string(veloxCfg_->get(kJoinSpillEnabled, true)); configs[velox::core::QueryConfig::kOrderBySpillEnabled] = diff --git a/cpp/velox/substrait/SubstraitExtensionCollector.cc b/cpp/velox/substrait/SubstraitExtensionCollector.cc index 0096821ce498..472ef04e3fc5 100644 --- a/cpp/velox/substrait/SubstraitExtensionCollector.cc +++ b/cpp/velox/substrait/SubstraitExtensionCollector.cc @@ -28,15 +28,6 @@ int SubstraitExtensionCollector::getReferenceNumber( return getReferenceNumber({"", substraitFunctionSignature}); } -int SubstraitExtensionCollector::getReferenceNumber( - const std::string& functionName, - const std::vector& arguments, - const core::AggregationNode::Step /* aggregationStep */) { - // TODO: Ignore aggregationStep for now, will refactor when introduce velox - // registry for function signature binding - return getReferenceNumber(functionName, arguments); -} - template bool SubstraitExtensionCollector::BiDirectionHashMap::putIfAbsent(const int& key, const T& value) { if (forwardMap_.find(key) == forwardMap_.end() && reverseMap_.find(value) == reverseMap_.end()) { diff --git a/cpp/velox/substrait/SubstraitExtensionCollector.h b/cpp/velox/substrait/SubstraitExtensionCollector.h index 137bea35653d..32cb25c92a41 100644 --- a/cpp/velox/substrait/SubstraitExtensionCollector.h +++ b/cpp/velox/substrait/SubstraitExtensionCollector.h @@ -53,13 +53,6 @@ class SubstraitExtensionCollector { /// using ExtensionFunctionId. int getReferenceNumber(const std::string& functionName, const std::vector& arguments); - /// Given an aggregate function name and argument types and aggregation Step, - /// return the functionId using ExtensionFunctionId. - int getReferenceNumber( - const std::string& functionName, - const std::vector& arguments, - core::AggregationNode::Step aggregationStep); - /// Add extension functions to Substrait plan. void addExtensionsToPlan(::substrait::Plan* plan) const; diff --git a/cpp/velox/substrait/SubstraitToVeloxPlan.cc b/cpp/velox/substrait/SubstraitToVeloxPlan.cc index 0826ca050ce0..6717a3714bb2 100644 --- a/cpp/velox/substrait/SubstraitToVeloxPlan.cc +++ b/cpp/velox/substrait/SubstraitToVeloxPlan.cc @@ -194,26 +194,58 @@ core::PlanNodePtr SubstraitToVeloxPlanConverter::processEmit( } core::AggregationNode::Step SubstraitToVeloxPlanConverter::toAggregationStep(const ::substrait::AggregateRel& aggRel) { - if (aggRel.measures().size() == 0) { - // When only groupings exist, set the phase to be Single. - return core::AggregationNode::Step::kSingle; + // TODO Simplify Velox's aggregation steps + if (aggRel.has_advanced_extension() && + SubstraitParser::configSetInOptimization(aggRel.advanced_extension(), "allowFlush=")) { + return core::AggregationNode::Step::kPartial; } + return core::AggregationNode::Step::kSingle; +} - // Use the first measure to set aggregation phase. - const auto& firstMeasure = aggRel.measures()[0]; - const auto& aggFunction = firstMeasure.measure(); - switch (aggFunction.phase()) { +/// Get aggregation function step for AggregateFunction. +/// The returned step value will be used to decide which Velox aggregate function or companion function +/// is used for the actual data processing. +core::AggregationNode::Step SubstraitToVeloxPlanConverter::toAggregationFunctionStep( + const ::substrait::AggregateFunction& sAggFuc) { + const auto& phase = sAggFuc.phase(); + switch (phase) { + case ::substrait::AGGREGATION_PHASE_UNSPECIFIED: + VELOX_FAIL("Aggregation phase not specified.") + break; case ::substrait::AGGREGATION_PHASE_INITIAL_TO_INTERMEDIATE: return core::AggregationNode::Step::kPartial; case ::substrait::AGGREGATION_PHASE_INTERMEDIATE_TO_INTERMEDIATE: return core::AggregationNode::Step::kIntermediate; - case ::substrait::AGGREGATION_PHASE_INTERMEDIATE_TO_RESULT: - return core::AggregationNode::Step::kFinal; case ::substrait::AGGREGATION_PHASE_INITIAL_TO_RESULT: return core::AggregationNode::Step::kSingle; + case ::substrait::AGGREGATION_PHASE_INTERMEDIATE_TO_RESULT: + return core::AggregationNode::Step::kFinal; + default: + VELOX_FAIL("Unexpected aggregation phase.") + } +} + +std::string SubstraitToVeloxPlanConverter::toAggregationFunctionName( + const std::string& baseName, + const core::AggregationNode::Step& step) { + std::string suffix; + switch (step) { + case core::AggregationNode::Step::kPartial: + suffix = "_partial"; + break; + case core::AggregationNode::Step::kFinal: + suffix = "_merge_extract"; + break; + case core::AggregationNode::Step::kIntermediate: + suffix = "_merge"; + break; + case core::AggregationNode::Step::kSingle: + suffix = ""; + break; default: - VELOX_FAIL("Aggregate phase is not supported."); + VELOX_FAIL("Unexpected aggregation node step.") } + return baseName + suffix; } core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait::JoinRel& sJoin) { @@ -352,7 +384,8 @@ core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait:: } } const auto& aggFunction = measure.measure(); - auto funcName = SubstraitParser::findVeloxFunction(functionMap_, aggFunction.function_reference()); + auto baseFuncName = SubstraitParser::findVeloxFunction(functionMap_, aggFunction.function_reference()); + auto funcName = toAggregationFunctionName(baseFuncName, toAggregationFunctionStep(aggFunction)); std::vector aggParams; aggParams.reserve(aggFunction.arguments().size()); for (const auto& arg : aggFunction.arguments()) { diff --git a/cpp/velox/substrait/SubstraitToVeloxPlan.h b/cpp/velox/substrait/SubstraitToVeloxPlan.h index 1d90bfbbb80a..4650383f6fbc 100644 --- a/cpp/velox/substrait/SubstraitToVeloxPlan.h +++ b/cpp/velox/substrait/SubstraitToVeloxPlan.h @@ -157,8 +157,18 @@ class SubstraitToVeloxPlanConverter { std::vector& rightExprs); /// Get aggregation step from AggregateRel. + /// If returned Partial, it means the aggregate generated can leveraging flushing and abandoning like + /// what streaming pre-aggregation can do in MPP databases. core::AggregationNode::Step toAggregationStep(const ::substrait::AggregateRel& sAgg); + /// Get aggregation function step for AggregateFunction. + /// The returned step value will be used to decide which Velox aggregate function or companion function + /// is used for the actual data processing. + core::AggregationNode::Step toAggregationFunctionStep(const ::substrait::AggregateFunction& sAggFuc); + + /// We use companion functions if the aggregate is not single. + std::string toAggregationFunctionName(const std::string& baseName, const core::AggregationNode::Step& step); + /// Helper Function to convert Substrait sortField to Velox sortingKeys and /// sortingOrders. std::pair, std::vector> processSortField( diff --git a/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc b/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc index 5bc424c0d840..bb9340fb0c6b 100644 --- a/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc +++ b/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc @@ -933,6 +933,7 @@ bool SubstraitToVeloxPlanValidator::validateAggRelFunctionType(const ::substrait for (const auto& smea : aggRel.measures()) { const auto& aggFunction = smea.measure(); + const auto& funcStep = planConverter_.toAggregationFunctionStep(aggFunction); auto funcSpec = planConverter_.findFuncSpec(aggFunction.function_reference()); std::vector types; bool isDecimal = false; @@ -949,7 +950,9 @@ bool SubstraitToVeloxPlanValidator::validateAggRelFunctionType(const ::substrait err.message()); return false; } - auto funcName = SubstraitParser::mapToVeloxFunction(SubstraitParser::getNameBeforeDelimiter(funcSpec), isDecimal); + auto baseFuncName = + SubstraitParser::mapToVeloxFunction(SubstraitParser::getNameBeforeDelimiter(funcSpec), isDecimal); + auto funcName = planConverter_.toAggregationFunctionName(baseFuncName, funcStep); auto signaturesOpt = exec::getAggregateFunctionSignatures(funcName); if (!signaturesOpt) { logValidateMsg( @@ -962,8 +965,7 @@ bool SubstraitToVeloxPlanValidator::validateAggRelFunctionType(const ::substrait exec::SignatureBinder binder(*signature, types); if (binder.tryBind()) { auto resolveType = binder.tryResolveType( - exec::isPartialOutput(planConverter_.toAggregationStep(aggRel)) ? signature->intermediateType() - : signature->returnType()); + exec::isPartialOutput(funcStep) ? signature->intermediateType() : signature->returnType()); if (resolveType == nullptr) { logValidateMsg( "native validation failed due to: Validation failed for function " + funcName + @@ -1069,70 +1071,28 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::AggregateRel& ag // to be registered. static const std::unordered_set supportedAggFuncs = { "sum", - "sum_partial", - "sum_merge", "collect_set", "count", - "count_partial", - "count_merge", "avg", - "avg_partial", - "avg_merge", "min", - "min_partial", - "min_merge", "max", - "max_partial", - "max_merge", "min_by", - "min_by_partial", - "min_by_merge", "max_by", - "max_by_partial", - "max_by_merge", "stddev_samp", - "stddev_samp_partial", - "stddev_samp_merge", "stddev_pop", - "stddev_pop_partial", - "stddev_pop_merge", "bloom_filter_agg", "var_samp", - "var_samp_partial", - "var_samp_merge", "var_pop", - "var_pop_partial", - "var_pop_merge", "bit_and", - "bit_and_partial", - "bit_and_merge", "bit_or", - "bit_or_partial", - "bit_or_merge", "bit_xor", - "bit_xor_partial", - "bit_xor_merge", "first", - "first_partial", - "first_merge", "first_ignore_null", - "first_ignore_null_partial", - "first_ignore_null_merge", "last", - "last_partial", - "last_merge", "last_ignore_null", - "last_ignore_null_partial", - "last_ignore_null_merge", "corr", - "corr_partial", - "corr_merge", "covar_pop", - "covar_pop_partial", - "covar_pop_merge", "covar_samp", - "covar_samp_partial", - "covar_samp_merge", "approx_distinct"}; for (const auto& funcSpec : funcSpecs) { diff --git a/cpp/velox/substrait/VeloxSubstraitSignature.cc b/cpp/velox/substrait/VeloxSubstraitSignature.cc index ef1055f582cb..891628ac0d63 100644 --- a/cpp/velox/substrait/VeloxSubstraitSignature.cc +++ b/cpp/velox/substrait/VeloxSubstraitSignature.cc @@ -50,8 +50,20 @@ std::string VeloxSubstraitSignature::toSubstraitSignature(const TypePtr& type) { return "list"; case TypeKind::MAP: return "map"; - case TypeKind::ROW: - return "struct"; + case TypeKind::ROW: { + std::stringstream buffer; + buffer << "struct<"; + const auto& rt = asRowType(type); + for (size_t i = 0; i < rt->children().size(); i++) { + buffer << toSubstraitSignature(rt->childAt(i)); + if (i == rt->children().size() - 1) { + continue; + } + buffer << ","; + } + buffer << ">"; + return buffer.str(); + } case TypeKind::UNKNOWN: return "u!name"; default: diff --git a/cpp/velox/substrait/VeloxToSubstraitPlan.cc b/cpp/velox/substrait/VeloxToSubstraitPlan.cc index 409b85498ddc..21417c3b9ba4 100644 --- a/cpp/velox/substrait/VeloxToSubstraitPlan.cc +++ b/cpp/velox/substrait/VeloxToSubstraitPlan.cc @@ -16,11 +16,45 @@ */ #include "VeloxToSubstraitPlan.h" +#include +#include "utils/exception.h" namespace gluten { namespace { -::substrait::AggregationPhase toAggregationPhase(core::AggregationNode::Step step) { + +struct AggregateCompanion { + std::string functionName; + core::AggregationNode::Step step; +}; + +AggregateCompanion toAggregateCompanion(const core::AggregationNode::Aggregate& aggregate) { + const auto& companionName = aggregate.call->name(); + auto offset = companionName.find_last_of('_'); + if (offset == std::string::npos) { + return {companionName, core::AggregationNode::Step::kSingle}; + } + // found '_' + const auto& suffix = companionName.substr(offset + 1); + if (suffix.empty()) { + // the last char is '_' + return {companionName, core::AggregationNode::Step::kSingle}; + } + const auto& functionName = companionName.substr(0, offset); + if (suffix == "_partial") { + return {functionName, core::AggregationNode::Step::kPartial}; + } + if (suffix == "_merge_extract") { + return {functionName, core::AggregationNode::Step::kFinal}; + } + if (suffix == "_merge") { + return {functionName, core::AggregationNode::Step::kIntermediate}; + } + // others, not a companion function + return {companionName, core::AggregationNode::Step::kSingle}; +} + +::substrait::AggregationPhase toAggregationPhase(const core::AggregationNode::Step& step) { switch (step) { case core::AggregationNode::Step::kPartial: { return ::substrait::AGGREGATION_PHASE_INITIAL_TO_INTERMEDIATE; @@ -242,10 +276,26 @@ void VeloxToSubstraitPlanConvertor::toSubstrait( // Process measure, eg:sum(a). ::substrait::AggregateFunction* aggFunction = aggMeasures->mutable_measure(); - // Aggregation function name. - const auto& funName = aggregate.call->name(); - // set aggFunction args. + // Use aggregate node's step information to write advanced extension 'allowFlush'. + const auto& step = aggregateNode->step(); + switch (step) { + case core::AggregationNode::Step::kPartial: { + substrait::extensions::AdvancedExtension ae{}; + google::protobuf::StringValue msg; + msg.set_value("allowFlush=1"); + ae.mutable_optimization()->PackFrom(msg); + aggregateRel->mutable_advanced_extension()->MergeFrom(ae); + break; + } + case core::AggregationNode::Step::kSingle: + break; + case core::AggregationNode::Step::kFinal: + case core::AggregationNode::Step::kIntermediate: + VELOX_USER_FAIL("Step not supported"); + break; + } + // Set aggFunction args. std::vector arguments; arguments.reserve(aggregate.call->inputs().size()); for (const auto& expr : aggregate.call->inputs()) { @@ -260,15 +310,16 @@ void VeloxToSubstraitPlanConvertor::toSubstrait( } } + const auto& aggregateCompanion = toAggregateCompanion(aggregate); auto referenceNumber = - extensionCollector_->getReferenceNumber(funName, aggregate.rawInputTypes, aggregateNode->step()); + extensionCollector_->getReferenceNumber(aggregateCompanion.functionName, aggregate.rawInputTypes); aggFunction->set_function_reference(referenceNumber); aggFunction->mutable_output_type()->MergeFrom(typeConvertor_->toSubstraitType(arena, aggregate.call->type())); // Set substrait aggregate Function phase. - aggFunction->set_phase(toAggregationPhase(aggregateNode->step())); + aggFunction->set_phase(toAggregationPhase(aggregateCompanion.step)); } // Direct output. diff --git a/cpp/velox/tests/SubstraitExtensionCollectorTest.cc b/cpp/velox/tests/SubstraitExtensionCollectorTest.cc index 815d17b00328..b1a9ae9444f4 100644 --- a/cpp/velox/tests/SubstraitExtensionCollectorTest.cc +++ b/cpp/velox/tests/SubstraitExtensionCollectorTest.cc @@ -43,9 +43,9 @@ class SubstraitExtensionCollectorTest : public ::testing::Test { const std::string& functionName, std::vector&& arguments, core::AggregationNode::Step step) { - int referenceNumber1 = extensionCollector_->getReferenceNumber(functionName, arguments, step); + int referenceNumber1 = extensionCollector_->getReferenceNumber(functionName, arguments); // Repeat the call to make sure properly de-duplicated. - int referenceNumber2 = extensionCollector_->getReferenceNumber(functionName, arguments, step); + int referenceNumber2 = extensionCollector_->getReferenceNumber(functionName, arguments); EXPECT_EQ(referenceNumber1, referenceNumber2); return referenceNumber2; } @@ -116,7 +116,7 @@ TEST_F(SubstraitExtensionCollectorTest, addExtensionsToPlan) { ASSERT_EQ(getFunctionName(3), "array_sum:list"); ASSERT_EQ(getFunctionName(4), "sum:i32"); ASSERT_EQ(getFunctionName(5), "avg:i32"); - ASSERT_EQ(getFunctionName(6), "avg:struct"); + ASSERT_EQ(getFunctionName(6), "avg:struct"); ASSERT_EQ(getFunctionName(7), "count:i32"); } diff --git a/cpp/velox/tests/VeloxSubstraitRoundTripTest.cc b/cpp/velox/tests/VeloxSubstraitRoundTripTest.cc index cf5b72b4d366..68e79c80f5b9 100644 --- a/cpp/velox/tests/VeloxSubstraitRoundTripTest.cc +++ b/cpp/velox/tests/VeloxSubstraitRoundTripTest.cc @@ -18,6 +18,8 @@ #include #include +#include "operators/functions/RegistrationAllFunctions.h" + #include "velox/common/base/tests/GTestUtils.h" #include "velox/exec/tests/utils/OperatorTestBase.h" #include "velox/exec/tests/utils/PlanBuilder.h" @@ -186,6 +188,7 @@ TEST_F(VeloxSubstraitRoundTripTest, countAll) { } TEST_F(VeloxSubstraitRoundTripTest, sum) { + GTEST_SKIP(); // Only partial step and single step of aggregation is currently supported. auto vectors = makeVectors(2, 7, 3); createDuckDbTable(vectors); @@ -195,6 +198,7 @@ TEST_F(VeloxSubstraitRoundTripTest, sum) { } TEST_F(VeloxSubstraitRoundTripTest, sumAndCount) { + GTEST_SKIP(); // Only partial step and single step of aggregation is currently supported. auto vectors = makeVectors(2, 7, 3); createDuckDbTable(vectors); @@ -205,6 +209,7 @@ TEST_F(VeloxSubstraitRoundTripTest, sumAndCount) { } TEST_F(VeloxSubstraitRoundTripTest, sumGlobal) { + GTEST_SKIP(); // Only partial step and single step of aggregation is currently supported. auto vectors = makeVectors(2, 7, 3); createDuckDbTable(vectors); @@ -219,6 +224,7 @@ TEST_F(VeloxSubstraitRoundTripTest, sumGlobal) { } TEST_F(VeloxSubstraitRoundTripTest, sumMask) { + GTEST_SKIP(); // Only partial step and single step of aggregation is currently supported. auto vectors = makeVectors(2, 7, 3); createDuckDbTable(vectors); @@ -248,6 +254,7 @@ TEST_F(VeloxSubstraitRoundTripTest, rowConstructor) { } TEST_F(VeloxSubstraitRoundTripTest, projectAs) { + GTEST_SKIP(); // Only partial step and single step of aggregation is currently supported. RowVectorPtr vectors = makeRowVector( {makeFlatVector({0.905791934145, 0.968867771124}), makeFlatVector({2499109626526694126, 2342493223442167775}), @@ -264,6 +271,7 @@ TEST_F(VeloxSubstraitRoundTripTest, projectAs) { } TEST_F(VeloxSubstraitRoundTripTest, avg) { + GTEST_SKIP(); // Only partial step and single step of aggregation is currently supported. auto vectors = makeVectors(2, 7, 3); createDuckDbTable(vectors); @@ -447,10 +455,94 @@ TEST_F(VeloxSubstraitRoundTripTest, subField) { PlanBuilder().values({data}).project({"(cast(row_constructor(a, b) as row(a bigint, b bigint))).a"}).planNode(); assertFailingPlanConversion(plan, "Non-field expression is not supported"); } + +TEST_F(VeloxSubstraitRoundTripTest, sumCompanion) { + auto vectors = makeVectors(2, 7, 3); + createDuckDbTable(vectors); + + auto plan = PlanBuilder().values(vectors).singleAggregation({}, {"sum_partial(1)", "count_partial(c4)"}).planNode(); + + assertPlanConversion(plan, "SELECT sum(1), count(c4) FROM tmp"); +} + +TEST_F(VeloxSubstraitRoundTripTest, sumAndCountCompanion) { + auto vectors = makeVectors(2, 7, 3); + createDuckDbTable(vectors); + + auto plan = PlanBuilder() + .values(vectors) + .singleAggregation({}, {"sum_partial(c1)", "count_partial(c4)"}) + .singleAggregation({}, {"sum_merge_extract(a0)", "count_merge_extract(a1)"}) + .planNode(); + + assertPlanConversion(plan, "SELECT sum(c1), count(c4) FROM tmp"); +} + +TEST_F(VeloxSubstraitRoundTripTest, sumGlobalCompanion) { + auto vectors = makeVectors(2, 7, 3); + createDuckDbTable(vectors); + + // Global final aggregation. + auto plan = PlanBuilder() + .values(vectors) + .singleAggregation({"c0"}, {"sum_partial(c0)", "sum_partial(c1)"}) + .singleAggregation({"c0"}, {"sum_merge(a0)", "sum_merge(a1)"}) + .singleAggregation({"c0"}, {"sum_merge_extract(a0)", "sum_merge_extract(a1)"}) + .planNode(); + assertPlanConversion(plan, "SELECT c0, sum(c0), sum(c1) FROM tmp GROUP BY c0"); +} + +TEST_F(VeloxSubstraitRoundTripTest, sumMaskCompanion) { + auto vectors = makeVectors(2, 7, 3); + createDuckDbTable(vectors); + + auto plan = PlanBuilder() + .values(vectors) + .project({"c0", "c1", "c2 % 2 < 10 AS m0", "c3 % 3 = 0 AS m1"}) + .singleAggregation({}, {"sum_partial(c0)", "sum_partial(c0)", "sum_partial(c1)"}, {"m0", "m1", "m1"}) + .singleAggregation({}, {"sum_merge_extract(a0)", "sum_merge_extract(a1)", "sum_merge_extract(a2)"}) + .planNode(); + + assertPlanConversion( + plan, + "SELECT sum(c0) FILTER (WHERE c2 % 2 < 10), " + "sum(c0) FILTER (WHERE c3 % 3 = 0), sum(c1) FILTER (WHERE c3 % 3 = 0) " + "FROM tmp"); +} + +TEST_F(VeloxSubstraitRoundTripTest, projectAsCompanion) { + RowVectorPtr vectors = makeRowVector( + {makeFlatVector({0.905791934145, 0.968867771124}), + makeFlatVector({2499109626526694126, 2342493223442167775}), + makeFlatVector({581869302, -133711905})}); + createDuckDbTable({vectors}); + + auto plan = PlanBuilder() + .values({vectors}) + .filter("c0 < 0.5") + .project({"c1 * c2 as revenue"}) + .singleAggregation({}, {"sum_partial(revenue)"}) + .planNode(); + assertPlanConversion(plan, "SELECT sum(c1 * c2) as revenue FROM tmp WHERE c0 < 0.5"); +} + +TEST_F(VeloxSubstraitRoundTripTest, avgCompanion) { + auto vectors = makeVectors(2, 7, 3); + createDuckDbTable(vectors); + + auto plan = PlanBuilder() + .values(vectors) + .singleAggregation({}, {"avg_partial(c4)"}) + .singleAggregation({}, {"avg_merge_extract(a0)"}) + .planNode(); + + assertPlanConversion(plan, "SELECT avg(c4) FROM tmp"); +} + } // namespace gluten int main(int argc, char** argv) { - facebook::velox::functions::sparksql::registerFunctions(""); + gluten::registerAllFunctions(); testing::InitGoogleTest(&argc, argv); folly::init(&argc, &argv, false); return RUN_ALL_TESTS(); diff --git a/cpp/velox/tests/VeloxSubstraitSignatureTest.cc b/cpp/velox/tests/VeloxSubstraitSignatureTest.cc index fbfe14f7c92c..bbc1165add88 100644 --- a/cpp/velox/tests/VeloxSubstraitSignatureTest.cc +++ b/cpp/velox/tests/VeloxSubstraitSignatureTest.cc @@ -59,10 +59,10 @@ TEST_F(VeloxSubstraitSignatureTest, toSubstraitSignatureWithType) { ASSERT_EQ(toSubstraitSignature(ARRAY(BOOLEAN())), "list"); ASSERT_EQ(toSubstraitSignature(ARRAY(INTEGER())), "list"); ASSERT_EQ(toSubstraitSignature(MAP(INTEGER(), BIGINT())), "map"); - ASSERT_EQ(toSubstraitSignature(ROW({INTEGER(), BIGINT()})), "struct"); - ASSERT_EQ(toSubstraitSignature(ROW({ARRAY(INTEGER())})), "struct"); - ASSERT_EQ(toSubstraitSignature(ROW({MAP(INTEGER(), INTEGER())})), "struct"); - ASSERT_EQ(toSubstraitSignature(ROW({ROW({INTEGER()})})), "struct"); + ASSERT_EQ(toSubstraitSignature(ROW({INTEGER(), BIGINT()})), "struct"); + ASSERT_EQ(toSubstraitSignature(ROW({ARRAY(INTEGER())})), "struct"); + ASSERT_EQ(toSubstraitSignature(ROW({MAP(INTEGER(), INTEGER())})), "struct"); + ASSERT_EQ(toSubstraitSignature(ROW({ROW({INTEGER()})})), "struct>"); ASSERT_EQ(toSubstraitSignature(UNKNOWN()), "u!name"); } diff --git a/docs/get-started/Velox.md b/docs/get-started/Velox.md index 917bb0f072d1..f5b68deedb8c 100644 --- a/docs/get-started/Velox.md +++ b/docs/get-started/Velox.md @@ -327,7 +327,6 @@ Using the following configuration options to customize spilling: |--------------------------------------------------------------------------|---------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| | spark.gluten.sql.columnar.backend.velox.spillStrategy | auto | none: Disable spill on Velox backend; auto: Let Spark memory manager manage Velox's spilling | | spark.gluten.sql.columnar.backend.velox.spillFileSystem | local | The filesystem used to store spill data. local: The local file system. heap-over-local: Write files to JVM heap if having extra heap space. Otherwise write to local file system. | -| spark.gluten.sql.columnar.backend.velox.partialAggregationSpillEnabled | true | Whether spill is enabled on partial aggregations | | spark.gluten.sql.columnar.backend.velox.aggregationSpillEnabled | true | Whether spill is enabled on aggregations | | spark.gluten.sql.columnar.backend.velox.joinSpillEnabled | true | Whether spill is enabled on joins | | spark.gluten.sql.columnar.backend.velox.orderBySpillEnabled | true | Whether spill is enabled on sorts | diff --git a/ep/build-velox/src/get_velox.sh b/ep/build-velox/src/get_velox.sh index 478cd17cd02f..694650a34306 100755 --- a/ep/build-velox/src/get_velox.sh +++ b/ep/build-velox/src/get_velox.sh @@ -17,7 +17,7 @@ set -exu VELOX_REPO=https://github.com/oap-project/velox.git -VELOX_BRANCH=2024_01_05 +VELOX_BRANCH=2024_01_05-2 VELOX_HOME="" #Set on run gluten on HDFS diff --git a/gluten-core/src/main/java/io/glutenproject/substrait/expression/AggregateFunctionNode.java b/gluten-core/src/main/java/io/glutenproject/substrait/expression/AggregateFunctionNode.java index f7b0659a028b..eb952b8427f9 100644 --- a/gluten-core/src/main/java/io/glutenproject/substrait/expression/AggregateFunctionNode.java +++ b/gluten-core/src/main/java/io/glutenproject/substrait/expression/AggregateFunctionNode.java @@ -61,7 +61,7 @@ public AggregateFunction toProtobuf() { aggBuilder.setPhase(AggregationPhase.AGGREGATION_PHASE_INTERMEDIATE_TO_RESULT); break; default: - aggBuilder.setPhase(AggregationPhase.AGGREGATION_PHASE_UNSPECIFIED); + aggBuilder.setPhase(AggregationPhase.AGGREGATION_PHASE_INITIAL_TO_RESULT); } } for (ExpressionNode expressionNode : expressionNodes) { diff --git a/gluten-core/src/main/scala/io/glutenproject/execution/HashAggregateExecBaseTransformer.scala b/gluten-core/src/main/scala/io/glutenproject/execution/HashAggregateExecBaseTransformer.scala index e925a0d887a1..5f4cd2c66050 100644 --- a/gluten-core/src/main/scala/io/glutenproject/execution/HashAggregateExecBaseTransformer.scala +++ b/gluten-core/src/main/scala/io/glutenproject/execution/HashAggregateExecBaseTransformer.scala @@ -565,17 +565,18 @@ abstract class HashAggregateExecBaseTransformer( null } - val isStreaming = if (isCapableForStreamingAggregation) { - "1" - } else { - "0" - } val optimization = BackendsApiManager.getTransformerApiInstance.packPBMessage( - StringValue.newBuilder.setValue(s"isStreaming=$isStreaming\n").build) + StringValue.newBuilder + .setValue(formatExtOptimizationString(isCapableForStreamingAggregation)) + .build) ExtensionBuilder.makeAdvancedExtension(optimization, enhancement) } + protected def formatExtOptimizationString(isStreaming: Boolean): String = { + s"isStreaming=${if (isStreaming) "1" else "0"}\n" + } + protected def getAggRel( context: SubstraitContext, operatorId: Long, diff --git a/gluten-core/src/test/scala/io/glutenproject/execution/WholeStageTransformerSuite.scala b/gluten-core/src/test/scala/io/glutenproject/execution/WholeStageTransformerSuite.scala index 9f4653451dfb..770ab72cdce2 100644 --- a/gluten-core/src/test/scala/io/glutenproject/execution/WholeStageTransformerSuite.scala +++ b/gluten-core/src/test/scala/io/glutenproject/execution/WholeStageTransformerSuite.scala @@ -248,7 +248,7 @@ abstract class WholeStageTransformerSuite extends GlutenQueryTest with SharedSpa */ def checkOperatorMatch[T <: TransformSupport](df: DataFrame)(implicit tag: ClassTag[T]): Unit = { val executedPlan = getExecutedPlan(df) - assert(executedPlan.exists(plan => plan.getClass == tag.runtimeClass)) + assert(executedPlan.exists(plan => tag.runtimeClass.isInstance(plan))) } /**