Skip to content

Commit

Permalink
[VL] Add FlushableHashAggregateExecTransformer to map Velox's partial…
Browse files Browse the repository at this point in the history
… aggregation which supports flushing and abandoning (#4130)
  • Loading branch information
zhztheplayer authored Jan 5, 2024
1 parent d1d976b commit d8248ea
Show file tree
Hide file tree
Showing 20 changed files with 397 additions and 164 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ class SparkPlanExecApiImpl extends SparkPlanExecApi {
initialInputBufferOffset: Int,
resultExpressions: Seq[NamedExpression],
child: SparkPlan): HashAggregateExecBaseTransformer =
HashAggregateExecTransformer(
RegularHashAggregateExecTransformer(
requiredChildDistributionExpressions,
groupingExpressions,
aggregateExpressions,
Expand Down Expand Up @@ -488,7 +488,8 @@ class SparkPlanExecApiImpl extends SparkPlanExecApi {
*
* @return
*/
override def genExtendedColumnarPreRules(): List[SparkSession => Rule[SparkPlan]] = List()
override def genExtendedColumnarPreRules(): List[SparkSession => Rule[SparkPlan]] =
List()

/**
* Generate extended columnar post-rules.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ import java.util.{ArrayList => JArrayList, HashMap => JHashMap, List => JList}
import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer

case class HashAggregateExecTransformer(
abstract class HashAggregateExecTransformer(
requiredChildDistributionExpressions: Option[Seq[Expression]],
groupingExpressions: Seq[NamedExpression],
aggregateExpressions: Seq[AggregateExpression],
Expand Down Expand Up @@ -69,10 +69,6 @@ case class HashAggregateExecTransformer(
}
}

/** Returns a copy of this transformer in which `child` is replaced by `newChild`. */
override protected def withNewChildInternal(newChild: SparkPlan): HashAggregateExecTransformer = {
copy(child = newChild)
}

/**
* Returns whether extracting subfield from struct is needed. True when the intermediate type of
* Velox aggregation is a compound type.
Expand Down Expand Up @@ -173,15 +169,16 @@ case class HashAggregateExecTransformer(
}
}

override protected def modeToKeyWord(aggregateMode: AggregateMode): String = {
super.modeToKeyWord(if (mixedPartialAndMerge) {
Partial
} else {
aggregateMode match {
case PartialMerge => Final
case _ => aggregateMode
}
})
// Whether the output may be merely pre-aggregated rather than fully
// aggregated. If true, the aggregation may flush its in-memory
// aggregated data whenever needed instead of waiting for all input
// to be read.
protected def allowFlush: Boolean

/**
 * Builds the extension optimization string. It consists of two newline-terminated
 * `key=value` entries, `isStreaming` and `allowFlush`, each encoded as "1" (true) or "0" (false).
 *
 * @param isStreaming
 *   value reported for the `isStreaming` entry.
 * @return
 *   the formatted option string.
 */
override protected def formatExtOptimizationString(isStreaming: Boolean): String = {
  // Encodes a boolean option the way both entries expect it.
  def asFlag(enabled: Boolean): String = if (enabled) "1" else "0"

  s"isStreaming=${asFlag(isStreaming)}\nallowFlush=${asFlag(allowFlush)}\n"
}

// Create aggregate function node and add to list.
Expand All @@ -191,15 +188,13 @@ case class HashAggregateExecTransformer(
childrenNodeList: JList[ExpressionNode],
aggregateMode: AggregateMode,
aggregateNodeList: JList[AggregateFunctionNode]): Unit = {
// This is a special handling for PartialMerge in the execution of distinct.
// Use Partial phase instead for this aggregation.
val modeKeyWord = modeToKeyWord(aggregateMode)

def generateMergeCompanionNode(): Unit = {
aggregateMode match {
case Partial =>
val partialNode = ExpressionBuilder.makeAggregateFunction(
VeloxAggregateFunctionsBuilder.create(args, aggregateFunction),
VeloxAggregateFunctionsBuilder.create(args, aggregateFunction, aggregateMode),
childrenNodeList,
modeKeyWord,
VeloxIntermediateData.getIntermediateTypeNode(aggregateFunction)
Expand All @@ -208,15 +203,15 @@ case class HashAggregateExecTransformer(
case PartialMerge =>
val aggFunctionNode = ExpressionBuilder.makeAggregateFunction(
VeloxAggregateFunctionsBuilder
.create(args, aggregateFunction, mixedPartialAndMerge, purePartialMerge),
.create(args, aggregateFunction, aggregateMode),
childrenNodeList,
modeKeyWord,
VeloxIntermediateData.getIntermediateTypeNode(aggregateFunction)
)
aggregateNodeList.add(aggFunctionNode)
case Final =>
val aggFunctionNode = ExpressionBuilder.makeAggregateFunction(
VeloxAggregateFunctionsBuilder.create(args, aggregateFunction),
VeloxAggregateFunctionsBuilder.create(args, aggregateFunction, aggregateMode),
childrenNodeList,
modeKeyWord,
ConverterUtils.getTypeNode(aggregateFunction.dataType, aggregateFunction.nullable)
Expand All @@ -233,7 +228,7 @@ case class HashAggregateExecTransformer(
case Partial =>
// For Partial mode output type is binary.
val partialNode = ExpressionBuilder.makeAggregateFunction(
VeloxAggregateFunctionsBuilder.create(args, aggregateFunction),
VeloxAggregateFunctionsBuilder.create(args, aggregateFunction, aggregateMode),
childrenNodeList,
modeKeyWord,
ConverterUtils.getTypeNode(
Expand All @@ -244,7 +239,7 @@ case class HashAggregateExecTransformer(
case Final =>
// For Final mode output type is long.
val aggFunctionNode = ExpressionBuilder.makeAggregateFunction(
VeloxAggregateFunctionsBuilder.create(args, aggregateFunction),
VeloxAggregateFunctionsBuilder.create(args, aggregateFunction, aggregateMode),
childrenNodeList,
modeKeyWord,
ConverterUtils.getTypeNode(aggregateFunction.dataType, aggregateFunction.nullable)
Expand All @@ -257,11 +252,7 @@ case class HashAggregateExecTransformer(
generateMergeCompanionNode()
case _ =>
val aggFunctionNode = ExpressionBuilder.makeAggregateFunction(
VeloxAggregateFunctionsBuilder.create(
args,
aggregateFunction,
aggregateMode == PartialMerge && mixedPartialAndMerge,
aggregateMode == PartialMerge && purePartialMerge),
VeloxAggregateFunctionsBuilder.create(args, aggregateFunction, aggregateMode),
childrenNodeList,
modeKeyWord,
ConverterUtils.getTypeNode(aggregateFunction.dataType, aggregateFunction.nullable)
Expand Down Expand Up @@ -364,7 +355,8 @@ case class HashAggregateExecTransformer(
val aggFunc = aggregateExpression.aggregateFunction
val functionInputAttributes = aggFunc.inputAggBufferAttributes
aggFunc match {
case _ if mixedPartialAndMerge && aggregateExpression.mode == Partial =>
case _
if aggregateExpression.mode == Partial => // FIXME: Any difference with the last branch?
val childNodes = aggFunc.children
.map(
ExpressionConverter
Expand Down Expand Up @@ -498,21 +490,6 @@ case class HashAggregateExecTransformer(
operatorId)
}

/**
 * Tells whether this aggregation mixes partial and partial-merge aggregate functions.
 * @return
 *   true only when at least one aggregate expression is in Partial mode and at least one is in
 *   PartialMerge mode.
 */
def mixedPartialAndMerge: Boolean = {
  val modes = aggregateExpressions.map(_.mode)
  modes.contains(PartialMerge) && modes.contains(Partial)
}

/**
 * Tells whether every aggregate expression is in PartialMerge mode. This is vacuously true when
 * there are no aggregate expressions.
 */
def purePartialMerge: Boolean =
  !aggregateExpressions.exists(_.mode != PartialMerge)

/**
* Create and return the Rel for the this aggregation.
* @param context
Expand Down Expand Up @@ -589,8 +566,7 @@ object VeloxAggregateFunctionsBuilder {
def create(
args: java.lang.Object,
aggregateFunc: AggregateFunction,
forMergeCompanion: Boolean = false,
purePartialMerge: Boolean = false): Long = {
mode: AggregateMode): Long = {
val functionMap = args.asInstanceOf[JHashMap[String, JLong]]

var sigName = ExpressionMappings.expressionsMap.get(aggregateFunc.getClass)
Expand All @@ -606,22 +582,71 @@ object VeloxAggregateFunctionsBuilder {
case _ =>
}

// Use companion function for partial-merge aggregation functions on count distinct.
val substraitAggFuncName = {
if (purePartialMerge) {
sigName.get + "_partial"
} else if (forMergeCompanion) {
sigName.get + "_merge"
} else {
sigName.get
}
}

ExpressionBuilder.newScalarFunction(
functionMap,
ConverterUtils.makeFuncName(
substraitAggFuncName,
VeloxIntermediateData.getInputTypes(aggregateFunc, forMergeCompanion),
FunctionConfig.REQ))
// Substrait-to-Velox procedure will choose appropriate companion function if needed.
sigName.get,
VeloxIntermediateData.getInputTypes(aggregateFunc, mode == PartialMerge || mode == Final),
FunctionConfig.REQ
)
)
}
}

/**
 * Hash aggregation that emits fully aggregated data. It works like the regular hash aggregation
 * in vanilla Spark.
 */
case class RegularHashAggregateExecTransformer(
    requiredChildDistributionExpressions: Option[Seq[Expression]],
    groupingExpressions: Seq[NamedExpression],
    aggregateExpressions: Seq[AggregateExpression],
    aggregateAttributes: Seq[Attribute],
    initialInputBufferOffset: Int,
    resultExpressions: Seq[NamedExpression],
    child: SparkPlan)
  extends HashAggregateExecTransformer(
    requiredChildDistributionExpressions,
    groupingExpressions,
    aggregateExpressions,
    aggregateAttributes,
    initialInputBufferOffset,
    resultExpressions,
    child) {

  // The output has to be fully aggregated, so flushing is not allowed.
  override protected def allowFlush: Boolean = false

  // Same textual representation as the parent class.
  override def simpleString(maxFields: Int): String = super.simpleString(maxFields)

  // Rebuilds this node on top of a different child plan; every other field is kept.
  override protected def withNewChildInternal(newChild: SparkPlan): HashAggregateExecTransformer =
    copy(child = newChild)
}

/**
 * Hash aggregation that emits pre-aggregated data. Its output rows are allowed to contain
 * duplicated grouping keys.
 */
case class FlushableHashAggregateExecTransformer(
    requiredChildDistributionExpressions: Option[Seq[Expression]],
    groupingExpressions: Seq[NamedExpression],
    aggregateExpressions: Seq[AggregateExpression],
    aggregateAttributes: Seq[Attribute],
    initialInputBufferOffset: Int,
    resultExpressions: Seq[NamedExpression],
    child: SparkPlan)
  extends HashAggregateExecTransformer(
    requiredChildDistributionExpressions,
    groupingExpressions,
    aggregateExpressions,
    aggregateAttributes,
    initialInputBufferOffset,
    resultExpressions,
    child) {

  // Pre-aggregated output is acceptable here, so flushing is allowed.
  override protected def allowFlush: Boolean = true

  // Prefixes the parent's representation with "Intermediate" to tell the two variants apart.
  override def simpleString(maxFields: Int): String = {
    val base = super.simpleString(maxFields)
    s"Intermediate$base"
  }

  // Rebuilds this node on top of a different child plan; every other field is kept.
  override protected def withNewChildInternal(newChild: SparkPlan): HashAggregateExecTransformer =
    copy(child = newChild)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst

import io.glutenproject.execution.{FlushableHashAggregateExecTransformer, RegularHashAggregateExecTransformer}

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.aggregate.{Partial, PartialMerge}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike

/**
 * Transforms a regular hash aggregation into a flushable (intermediate) hash aggregation, which
 * internally enables optimizations such as flushing and abandoning.
 *
 * The conversion applies only to a [[RegularHashAggregateExecTransformer]] that is the direct
 * child of a shuffle exchange and whose aggregate functions are all in Partial or PartialMerge
 * mode. Every other shuffle exchange is returned unchanged.
 *
 * Currently not in use. Will be enabled via a configuration after necessary verification is done.
 *
 * @param session
 *   the active Spark session (not read by this rule).
 */
case class FlushableHashAggregateRule(session: SparkSession) extends Rule[SparkPlan] {
  override def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
    case shuffle: ShuffleExchangeLike =>
      shuffle.child match {
        // If an exchange follows a hash aggregate in which all functions are in partial mode,
        // then it's safe to convert the hash aggregate to a flushable hash aggregate.
        case h: RegularHashAggregateExecTransformer
            if h.aggregateExpressions.forall(p => p.mode == Partial || p.mode == PartialMerge) =>
          shuffle.withNewChildren(
            Seq(
              FlushableHashAggregateExecTransformer(
                h.requiredChildDistributionExpressions,
                h.groupingExpressions,
                h.aggregateExpressions,
                h.aggregateAttributes,
                h.initialInputBufferOffset,
                h.resultExpressions,
                h.child
              )))
        // Fallback for shuffles whose child is not a convertible regular hash aggregate.
        // Without it, the match is non-exhaustive and throws scala.MatchError for any
        // other kind of child plan.
        case _ => shuffle
      }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -631,14 +631,27 @@ class TestOperator extends VeloxWholeStageTransformerSuite with AdaptiveSparkPla
}
}

test("Support get native plan tree string") {
test("Support get native plan tree string, Velox single aggregation") {
runQueryAndCompare("select l_partkey + 1, count(*) from lineitem group by l_partkey + 1") {
df =>
val wholeStageTransformers = collect(df.queryExecution.executedPlan) {
case w: WholeStageTransformer => w
}
val nativePlanString = wholeStageTransformers.head.nativePlanString()
assert(nativePlanString.contains("Aggregation[FINAL"))
assert(nativePlanString.contains("Aggregation[SINGLE"))
assert(nativePlanString.contains("TableScan"))
}
}

// TODO: Re-enable this test after FlushableHashAggregateRule (intermediate hash aggregation) is enabled.
ignore("Support get native plan tree string") {
runQueryAndCompare("select l_partkey + 1, count(*) from lineitem group by l_partkey + 1") {
df =>
val wholeStageTransformers = collect(df.queryExecution.executedPlan) {
case w: WholeStageTransformer => w
}
val nativePlanString = wholeStageTransformers.head.nativePlanString()
assert(nativePlanString.contains("Aggregation[SINGLE"))
assert(nativePlanString.contains("Aggregation[PARTIAL"))
assert(nativePlanString.contains("TableScan"))
}
Expand Down
4 changes: 0 additions & 4 deletions cpp/velox/compute/WholeStageResultIterator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,6 @@ const std::string kHiveConnectorId = "test-hive";
// memory
const std::string kSpillStrategy = "spark.gluten.sql.columnar.backend.velox.spillStrategy";
const std::string kSpillStrategyDefaultValue = "auto";
const std::string kPartialAggregationSpillEnabled =
"spark.gluten.sql.columnar.backend.velox.partialAggregationSpillEnabled";
const std::string kAggregationSpillEnabled = "spark.gluten.sql.columnar.backend.velox.aggregationSpillEnabled";
const std::string kJoinSpillEnabled = "spark.gluten.sql.columnar.backend.velox.joinSpillEnabled";
const std::string kOrderBySpillEnabled = "spark.gluten.sql.columnar.backend.velox.orderBySpillEnabled";
Expand Down Expand Up @@ -378,8 +376,6 @@ std::unordered_map<std::string, std::string> WholeStageResultIterator::getQueryC
}
configs[velox::core::QueryConfig::kAggregationSpillEnabled] =
std::to_string(veloxCfg_->get<bool>(kAggregationSpillEnabled, true));
configs[velox::core::QueryConfig::kPartialAggregationSpillEnabled] =
std::to_string(veloxCfg_->get<bool>(kPartialAggregationSpillEnabled, true));
configs[velox::core::QueryConfig::kJoinSpillEnabled] =
std::to_string(veloxCfg_->get<bool>(kJoinSpillEnabled, true));
configs[velox::core::QueryConfig::kOrderBySpillEnabled] =
Expand Down
9 changes: 0 additions & 9 deletions cpp/velox/substrait/SubstraitExtensionCollector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,6 @@ int SubstraitExtensionCollector::getReferenceNumber(
return getReferenceNumber({"", substraitFunctionSignature});
}

// Overload that additionally accepts an aggregation step. The step is currently unused:
// the call is forwarded to the (functionName, arguments) overload and its result returned.
int SubstraitExtensionCollector::getReferenceNumber(
const std::string& functionName,
const std::vector<TypePtr>& arguments,
const core::AggregationNode::Step /* aggregationStep */) {
// TODO: Ignore aggregationStep for now, will refactor when introduce velox
// registry for function signature binding
return getReferenceNumber(functionName, arguments);
}

template <typename T>
bool SubstraitExtensionCollector::BiDirectionHashMap<T>::putIfAbsent(const int& key, const T& value) {
if (forwardMap_.find(key) == forwardMap_.end() && reverseMap_.find(value) == reverseMap_.end()) {
Expand Down
7 changes: 0 additions & 7 deletions cpp/velox/substrait/SubstraitExtensionCollector.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,6 @@ class SubstraitExtensionCollector {
/// using ExtensionFunctionId.
int getReferenceNumber(const std::string& functionName, const std::vector<TypePtr>& arguments);

/// Given an aggregate function name and argument types and aggregation Step,
/// return the functionId using ExtensionFunctionId.
int getReferenceNumber(
const std::string& functionName,
const std::vector<TypePtr>& arguments,
core::AggregationNode::Step aggregationStep);

/// Add extension functions to Substrait plan.
void addExtensionsToPlan(::substrait::Plan* plan) const;

Expand Down
Loading

0 comments on commit d8248ea

Please sign in to comment.