diff --git a/backends-clickhouse/src/main/scala/io/glutenproject/execution/CHHashAggregateExecTransformer.scala b/backends-clickhouse/src/main/scala/io/glutenproject/execution/CHHashAggregateExecTransformer.scala
index ad997843eee6..aebc884798ae 100644
--- a/backends-clickhouse/src/main/scala/io/glutenproject/execution/CHHashAggregateExecTransformer.scala
+++ b/backends-clickhouse/src/main/scala/io/glutenproject/execution/CHHashAggregateExecTransformer.scala
@@ -18,10 +18,9 @@ package io.glutenproject.execution
 
 import io.glutenproject.execution.CHHashAggregateExecTransformer.getAggregateResultAttributes
 import io.glutenproject.expression._
-import io.glutenproject.substrait.`type`.{TypeBuilder, TypeNode}
+import io.glutenproject.substrait.`type`.TypeNode
 import io.glutenproject.substrait.{AggregationParams, SubstraitContext}
 import io.glutenproject.substrait.expression.{AggregateFunctionNode, ExpressionBuilder, ExpressionNode}
-import io.glutenproject.substrait.extensions.ExtensionBuilder
 import io.glutenproject.substrait.rel.{LocalFilesBuilder, RelBuilder, RelNode}
 
 import org.apache.spark.sql.catalyst.expressions._
@@ -30,8 +29,6 @@ import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
 import org.apache.spark.sql.types._
 
-import com.google.protobuf.Any
-
 import java.util
 
 object CHHashAggregateExecTransformer {
@@ -285,31 +282,16 @@ case class CHHashAggregateExecTransformer(
         )
         aggregateFunctionList.add(aggFunctionNode)
       })
-    if (!validation) {
-      RelBuilder.makeAggregateRel(
-        input,
-        groupingList,
-        aggregateFunctionList,
-        aggFilterList,
-        context,
-        operatorId)
-    } else {
-      // Use a extension node to send the input types through Substrait plan for validation.
-      val inputTypeNodeList = new java.util.ArrayList[TypeNode]()
-      for (attr <- originalInputAttributes) {
-        inputTypeNodeList.add(ConverterUtils.getTypeNode(attr.dataType, attr.nullable))
-      }
-      val extensionNode = ExtensionBuilder.makeAdvancedExtension(
-        Any.pack(TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf))
-      RelBuilder.makeAggregateRel(
-        input,
-        groupingList,
-        aggregateFunctionList,
-        aggFilterList,
-        extensionNode,
-        context,
-        operatorId)
-    }
+
+    val extensionNode = getAdvancedExtension(validation, originalInputAttributes)
+    RelBuilder.makeAggregateRel(
+      input,
+      groupingList,
+      aggregateFunctionList,
+      aggFilterList,
+      extensionNode,
+      context,
+      operatorId)
   }
 
   override def isStreaming: Boolean = false
diff --git a/backends-velox/src/main/scala/io/glutenproject/execution/HashAggregateExecTransformer.scala b/backends-velox/src/main/scala/io/glutenproject/execution/HashAggregateExecTransformer.scala
index 2d12eae0d41f..7e4a85eb11f1 100644
--- a/backends-velox/src/main/scala/io/glutenproject/execution/HashAggregateExecTransformer.scala
+++ b/backends-velox/src/main/scala/io/glutenproject/execution/HashAggregateExecTransformer.scala
@@ -654,11 +654,14 @@ case class HashAggregateExecTransformer(
       }
       addFunctionNode(args, aggregateFunc, childrenNodes, aggExpr.mode, aggregateFunctionList)
     })
+
+    val extensionNode = getAdvancedExtension()
     RelBuilder.makeAggregateRel(
       projectRel,
       groupingList,
       aggregateFunctionList,
       aggFilterList,
+      extensionNode,
       context,
       operatorId)
   }
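Both backends now defer to a shared `getAdvancedExtension` helper (added below in `HashAggregateExecBaseTransformer`), which ships its hints as a protobuf `Any` wrapping a `StringValue`. A minimal sketch of that packing round trip, independent of Gluten's builders (`OptimizationPayloadSketch` is an invented name):

```scala
import com.google.protobuf.{Any, StringValue}

object OptimizationPayloadSketch extends App {
  // Pack a key=value hint the same way the aggregate transformer does.
  val payload: Any =
    Any.pack(StringValue.newBuilder.setValue("isStreaming=1\n").build)

  // A consumer can unpack the Any back to a StringValue and scan for the flag.
  val flags: String = payload.unpack(classOf[StringValue]).getValue
  assert(flags.contains("isStreaming=1"))
}
```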
diff --git a/backends-velox/src/test/scala/org/apache/spark/sql/execution/benchmark/StreamingAggregateBenchmark.scala b/backends-velox/src/test/scala/org/apache/spark/sql/execution/benchmark/StreamingAggregateBenchmark.scala
new file mode 100644
index 000000000000..dc826f486ae4
--- /dev/null
+++ b/backends-velox/src/test/scala/org/apache/spark/sql/execution/benchmark/StreamingAggregateBenchmark.scala
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.benchmark
+
+import io.glutenproject.GlutenConfig
+
+import org.apache.spark.benchmark.Benchmark
+import org.apache.spark.sql.internal.SQLConf
+
+/**
+ * Benchmark to measure performance for streaming aggregate. To run this benchmark:
+ * {{{
+ *   bin/spark-submit --class <this class> --jars <spark sql test jar>
+ * }}}
+ */
+object StreamingAggregateBenchmark extends SqlBasedBenchmark {
+  private val numRows = {
+    spark.sparkContext.conf.getLong("spark.gluten.benchmark.rows", 8 * 1000 * 1000)
+  }
+
+  private val mode = {
+    spark.sparkContext.conf.getLong("spark.gluten.benchmark.remainder", 4 * 1000 * 1000)
+  }
+
+  private def doBenchmark(): Unit = {
+    val benchmark = new Benchmark("streaming aggregate", numRows, output = output)
+
+    val query =
+      """
+        |SELECT c1, count(*), sum(c2) FROM (
+        |SELECT t1.c1, t2.c2 FROM t t1 JOIN t t2 ON t1.c1 = t2.c1
+        |)
+        |GROUP BY c1
+        |""".stripMargin
+    benchmark.addCase(s"Enable streaming aggregate", 3) {
+      _ =>
+        withSQLConf(
+          GlutenConfig.COLUMNAR_PREFER_STREAMING_AGGREGATE.key -> "true",
+          SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+          GlutenConfig.COLUMNAR_FPRCE_SHUFFLED_HASH_JOIN_ENABLED.key -> "false"
+        ) {
+          spark.sql(query).noop()
+        }
+    }
+
+    benchmark.addCase(s"Disable streaming aggregate", 3) {
+      _ =>
+        withSQLConf(
+          GlutenConfig.COLUMNAR_PREFER_STREAMING_AGGREGATE.key -> "false",
+          SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+          GlutenConfig.COLUMNAR_FPRCE_SHUFFLED_HASH_JOIN_ENABLED.key -> "false"
+        ) {
+          spark.sql(query).noop()
+        }
+    }
+
+    benchmark.run()
+  }
+
+  override def runBenchmarkSuite(mainArgs: Array[String]): Unit = {
+    spark
+      .range(numRows)
+      .selectExpr(s"id % $mode as c1", "id as c2")
+      .write
+      .saveAsTable("t")
+
+    try {
+      doBenchmark()
+    } finally {
+      spark.sql("DROP TABLE t")
+    }
+  }
+}
diff --git a/cpp/velox/substrait/SubstraitToVeloxPlan.cc b/cpp/velox/substrait/SubstraitToVeloxPlan.cc
index f38c2eefacca..9c94f7d42a2c 100644
--- a/cpp/velox/substrait/SubstraitToVeloxPlan.cc
+++ b/cpp/velox/substrait/SubstraitToVeloxPlan.cc
@@ -365,6 +365,11 @@ core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait::AggregateRel& aggRel) {
   bool ignoreNullKeys = false;
   std::vector<core::FieldAccessTypedExprPtr> preGroupingExprs;
+  if (aggRel.has_advanced_extension() &&
+      SubstraitParser::configSetInOptimization(aggRel.advanced_extension(), "isStreaming=")) {
+    preGroupingExprs.reserve(veloxGroupingExprs.size());
+    preGroupingExprs.insert(preGroupingExprs.begin(), veloxGroupingExprs.begin(), veloxGroupingExprs.end());
+  }
 
   // Get the output names of Aggregation.
   std::vector<std::string> aggOutNames;
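On the native side, `SubstraitParser::configSetInOptimization` decides whether the advanced extension marks the aggregation as streaming; judging by its use here, it answers whether the optimization payload sets the named flag to `1`. A rough Scala model of that check, assuming the newline-separated `key=value` text produced by the Scala side (`configSet` and `ConfigSetSketch` are invented names):

```scala
import com.google.protobuf.{Any, StringValue}

object ConfigSetSketch extends App {
  // Hypothetical Scala model of SubstraitParser::configSetInOptimization:
  // true when the newline-separated payload contains "<prefix>1".
  def configSet(optimization: Any, prefix: String): Boolean = {
    val text = optimization.unpack(classOf[StringValue]).getValue
    text.split('\n').contains(prefix + "1")
  }

  val on = Any.pack(StringValue.newBuilder.setValue("isStreaming=1\n").build)
  val off = Any.pack(StringValue.newBuilder.setValue("isStreaming=0\n").build)
  assert(configSet(on, "isStreaming="))   // Velox plans with pre-grouped keys
  assert(!configSet(off, "isStreaming=")) // hash aggregation is kept
}
```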
diff --git a/gluten-core/src/main/java/io/glutenproject/substrait/rel/RelBuilder.java b/gluten-core/src/main/java/io/glutenproject/substrait/rel/RelBuilder.java
index 8dfb2f4a20af..93557ef6f108 100644
--- a/gluten-core/src/main/java/io/glutenproject/substrait/rel/RelBuilder.java
+++ b/gluten-core/src/main/java/io/glutenproject/substrait/rel/RelBuilder.java
@@ -81,17 +81,6 @@ public static RelNode makeProjectRel(
     return new ProjectRelNode(input, expressionNodes, extensionNode, emitStartIndex);
   }
 
-  public static RelNode makeAggregateRel(
-      RelNode input,
-      List<ExpressionNode> groupings,
-      List<AggregateFunctionNode> aggregateFunctionNodes,
-      List<ExpressionNode> filters,
-      SubstraitContext context,
-      Long operatorId) {
-    context.registerRelToOperator(operatorId);
-    return new AggregateRelNode(input, groupings, aggregateFunctionNodes, filters);
-  }
-
   public static RelNode makeAggregateRel(
       RelNode input,
       List<ExpressionNode> groupings,
diff --git a/gluten-core/src/main/scala/io/glutenproject/execution/BasicPhysicalOperatorTransformer.scala b/gluten-core/src/main/scala/io/glutenproject/execution/BasicPhysicalOperatorTransformer.scala
index 74ded1bdc224..89a28307e137 100644
--- a/gluten-core/src/main/scala/io/glutenproject/execution/BasicPhysicalOperatorTransformer.scala
+++ b/gluten-core/src/main/scala/io/glutenproject/execution/BasicPhysicalOperatorTransformer.scala
@@ -111,6 +111,8 @@ abstract class FilterExecTransformerBase(val cond: Expression, val input: SparkP
     }
   }
 
+  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
+
   override protected def doValidateInternal(): ValidationResult = {
     if (cond == null) {
       // The computing of this Filter is not needed.
@@ -181,6 +183,7 @@ abstract class FilterExecTransformerBase(val cond: Expression, val input: SparkP
 case class ProjectExecTransformer private (projectList: Seq[NamedExpression], child: SparkPlan)
   extends UnaryTransformSupport
   with PredicateHelper
+  with AliasAwareOutputOrdering
   with Logging {
 
   // Note: "metrics" is made transient to avoid sending driver-side metrics to tasks.
@@ -189,6 +192,10 @@ case class ProjectExecTransformer private (projectList: Seq[NamedExpression], ch
 
   val sparkConf: SparkConf = sparkContext.getConf
 
+  override protected def orderingExpressions: Seq[SortOrder] = child.outputOrdering
+
+  override protected def outputExpressions: Seq[NamedExpression] = projectList
+
   override protected def doValidateInternal(): ValidationResult = {
     val substraitContext = new SubstraitContext
     // Firstly, need to check if the Substrait plan for this operator can be successfully generated.
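These two changes let ordering information flow through the operators between a sort-producing join and the aggregate: a filter passes its child's ordering through unchanged, and the alias-aware mixin keeps an ordering expressed on `a` valid for a projection of `a AS b`. A small sketch of the satisfaction check the aggregate will later run (attribute and object names are invented):

```scala
import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, SortOrder}
import org.apache.spark.sql.types.LongType

object OrderingSatisfiesSketch extends App {
  val c1 = AttributeReference("c1", LongType)()
  // Ordering the child actually provides, e.g. from a sort-merge join on c1.
  val childOrdering = Seq(SortOrder(c1, Ascending))
  // Ordering a streaming aggregation on grouping key c1 requires.
  val requiredOrdering = Seq(SortOrder(c1, Ascending))
  assert(SortOrder.orderingSatisfies(childOrdering, requiredOrdering))
}
```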
diff --git a/gluten-core/src/main/scala/io/glutenproject/execution/HashAggregateExecBaseTransformer.scala b/gluten-core/src/main/scala/io/glutenproject/execution/HashAggregateExecBaseTransformer.scala
index 25298c53f710..358c93de3481 100644
--- a/gluten-core/src/main/scala/io/glutenproject/execution/HashAggregateExecBaseTransformer.scala
+++ b/gluten-core/src/main/scala/io/glutenproject/execution/HashAggregateExecBaseTransformer.scala
@@ -16,6 +16,7 @@
  */
 package io.glutenproject.execution
 
+import io.glutenproject.GlutenConfig
 import io.glutenproject.backendsapi.BackendsApiManager
 import io.glutenproject.expression._
 import io.glutenproject.extension.ValidationResult
@@ -23,7 +24,7 @@ import io.glutenproject.metrics.MetricsUpdater
 import io.glutenproject.substrait.`type`.TypeBuilder
 import io.glutenproject.substrait.{AggregationParams, SubstraitContext}
 import io.glutenproject.substrait.expression.{AggregateFunctionNode, ExpressionBuilder, ExpressionNode}
-import io.glutenproject.substrait.extensions.ExtensionBuilder
+import io.glutenproject.substrait.extensions.{AdvancedExtensionNode, ExtensionBuilder}
 import io.glutenproject.substrait.rel.{RelBuilder, RelNode}
 
 import org.apache.spark.rdd.RDD
@@ -35,7 +36,7 @@ import org.apache.spark.sql.execution.aggregate._
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.vectorized.ColumnarBatch
 
-import com.google.protobuf.Any
+import com.google.protobuf.{Any, StringValue}
 
 import java.util.{ArrayList => JArrayList, List => JList}
 
@@ -73,6 +74,20 @@ abstract class HashAggregateExecBaseTransformer(
       aggregateAttributes)
   }
 
+  protected def isGroupingKeysPreGrouped: Boolean = {
+    if (!conf.getConf(GlutenConfig.COLUMNAR_PREFER_STREAMING_AGGREGATE)) {
+      return false
+    }
+    val childOrdering = child match {
+      case agg: HashAggregateExecBaseTransformer
+          if agg.groupingExpressions == this.groupingExpressions =>
+        agg.child.outputOrdering
+      case _ => child.outputOrdering
+    }
+    val requiredOrdering = groupingExpressions.map(expr => SortOrder.apply(expr, Ascending))
+    SortOrder.orderingSatisfies(childOrdering, requiredOrdering)
+  }
+
   override def doExecuteColumnar(): RDD[ColumnarBatch] = {
     throw new UnsupportedOperationException(s"This operator doesn't support doExecuteColumnar().")
   }
@@ -328,11 +343,13 @@ abstract class HashAggregateExecBaseTransformer(
       }
     })
 
+    val extensionNode = getAdvancedExtension()
     RelBuilder.makeAggregateRel(
       inputRel,
       groupingList,
       aggregateFunctionList,
       aggFilterList,
+      extensionNode,
       context,
       operatorId)
   }
@@ -533,30 +550,39 @@ abstract class HashAggregateExecBaseTransformer(
         aggExpr.mode,
         aggregateFunctionList)
     })
-    if (!validation) {
-      RelBuilder.makeAggregateRel(
-        input,
-        groupingList,
-        aggregateFunctionList,
-        aggFilterList,
-        context,
-        operatorId)
-    } else {
+
+    val extensionNode = getAdvancedExtension(validation, originalInputAttributes)
+    RelBuilder.makeAggregateRel(
+      input,
+      groupingList,
+      aggregateFunctionList,
+      aggFilterList,
+      extensionNode,
+      context,
+      operatorId)
+  }
+
+  protected def getAdvancedExtension(
+      validation: Boolean = false,
+      originalInputAttributes: Seq[Attribute] = Seq.empty): AdvancedExtensionNode = {
+    val enhancement = if (validation) {
       // Use a extension node to send the input types through Substrait plan for validation.
       val inputTypeNodeList = originalInputAttributes
         .map(attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable))
         .asJava
-      val extensionNode = ExtensionBuilder.makeAdvancedExtension(
-        Any.pack(TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf))
-      RelBuilder.makeAggregateRel(
-        input,
-        groupingList,
-        aggregateFunctionList,
-        aggFilterList,
-        extensionNode,
-        context,
-        operatorId)
+      Any.pack(TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf)
+    } else {
+      null
+    }
+
+    val isStreaming = if (isGroupingKeysPreGrouped) {
+      "1"
+    } else {
+      "0"
     }
+    val optimization =
+      Any.pack(StringValue.newBuilder.setValue(s"isStreaming=$isStreaming\n").build)
+    ExtensionBuilder.makeAdvancedExtension(optimization, enhancement)
   }
 
   protected def getAggRel(
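One subtlety in `isGroupingKeysPreGrouped`: a final aggregate usually sits directly on its partial counterpart, whose output ordering is not tracked, so when the grouping expressions match it looks through to the partial aggregate's child, presumably safe because the partial aggregate will then be marked streaming as well and preserve that order. A toy rendering of the dispatch (the `Node` hierarchy is invented; only the catalyst classes are real):

```scala
import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, SortOrder}
import org.apache.spark.sql.types.LongType

object PreGroupedSketch extends App {
  // Invented mini plan: a partial aggregate over a sorted child.
  sealed trait Node { def outputOrdering: Seq[SortOrder] }
  case class Sorted(order: Seq[SortOrder]) extends Node { def outputOrdering = order }
  case class PartialAgg(keys: Seq[AttributeReference], child: Node) extends Node {
    def outputOrdering = Seq.empty // ordering of aggregate output is not tracked
  }

  def childOrderingFor(child: Node, groupingKeys: Seq[AttributeReference]): Seq[SortOrder] =
    child match {
      // Same grouping keys: look through the partial aggregate, as the patch does.
      case PartialAgg(keys, grandChild) if keys == groupingKeys => grandChild.outputOrdering
      case other => other.outputOrdering
    }

  val key = AttributeReference("c1", LongType)()
  val plan = PartialAgg(Seq(key), Sorted(Seq(SortOrder(key, Ascending))))
  val required = Seq(SortOrder(key, Ascending))
  assert(SortOrder.orderingSatisfies(childOrderingFor(plan, Seq(key)), required))
}
```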
diff --git a/shims/common/src/main/scala/io/glutenproject/GlutenConfig.scala b/shims/common/src/main/scala/io/glutenproject/GlutenConfig.scala
index 0922ac4088e2..72b761729b3b 100644
--- a/shims/common/src/main/scala/io/glutenproject/GlutenConfig.scala
+++ b/shims/common/src/main/scala/io/glutenproject/GlutenConfig.scala
@@ -644,6 +644,17 @@ object GlutenConfig {
       .checkValues(Set("streaming", "sort"))
       .createWithDefault("streaming")
 
+  val COLUMNAR_PREFER_STREAMING_AGGREGATE =
+    buildConf("spark.gluten.sql.columnar.preferStreamingAggregate")
+      .internal()
+      .doc(
+        "The Velox backend supports `StreamingAggregate`, which uses less memory " +
+          "since it does not need to hold all groups in memory and so can avoid spilling. " +
+          "When enabled and the child output ordering satisfies the grouping keys, " +
+          "Gluten chooses `StreamingAggregate` as the native operator.")
+      .booleanConf
+      .createWithDefault(true)
+
   val COLUMNAR_FPRCE_SHUFFLED_HASH_JOIN_ENABLED =
     buildConf("spark.gluten.sql.columnar.forceShuffledHashJoin")
       .internal()
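The flag defaults to on; for A/B comparisons like the benchmark above it can be flipped per session. An illustrative `spark-shell` snippet (table `t` is assumed to exist):

```scala
// Compare plans with streaming aggregation on and off.
spark.conf.set("spark.gluten.sql.columnar.preferStreamingAggregate", "true")
spark.sql("SELECT c1, count(*) FROM t GROUP BY c1").explain()

spark.conf.set("spark.gluten.sql.columnar.preferStreamingAggregate", "false")
spark.sql("SELECT c1, count(*) FROM t GROUP BY c1").explain()
```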