Add RewriteTypedImperativeAggregate rule for collect_list/collect_set
liujiayi771 committed Feb 24, 2024
1 parent b823592 commit 908cfa2
Showing 19 changed files with 157 additions and 51 deletions.
@@ -443,4 +443,10 @@ object BackendSettings extends BackendSettingsApi {
override def supportCartesianProductExec(): Boolean = true

override def supportBroadcastNestedLoopJoinExec(): Boolean = true

override def shouldRewriteTypedImperativeAggregate(): Boolean = {
// The intermediate types of collect_list and collect_set in the Velox backend are not consistent
// with vanilla Spark's, so the aggregate needs to be rewritten to get the correct data type.
true
}
}
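
The mismatch named in the comment above is easy to see with plain Spark APIs: vanilla Spark implements CollectList/CollectSet as TypedImperativeAggregate, whose declared intermediate type is an opaque binary buffer, while Velox's array_agg/set_agg exchange the array type directly. A minimal sketch using standard Catalyst classes (not part of this patch):

```scala
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.expressions.aggregate.CollectList
import org.apache.spark.sql.types.IntegerType

// TypedImperativeAggregate serializes its state, so the buffer schema
// Spark declares for collect_list is a single BinaryType attribute...
val cl = CollectList(Literal.default(IntegerType))
println(cl.aggBufferAttributes.map(_.dataType)) // List(BinaryType)
// ...while the Velox intermediate matches the function's result type.
println(cl.dataType)                            // ArrayType(IntegerType,false)
```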
@@ -19,6 +19,7 @@ package io.glutenproject.execution
import io.glutenproject.backendsapi.BackendsApiManager
import io.glutenproject.expression._
import io.glutenproject.expression.ConverterUtils.FunctionConfig
import io.glutenproject.extension.RewriteTypedImperativeAggregate
import io.glutenproject.substrait.`type`.{TypeBuilder, TypeNode}
import io.glutenproject.substrait.{AggregationParams, SubstraitContext}
import io.glutenproject.substrait.expression.{AggregateFunctionNode, ExpressionBuilder, ExpressionNode, ScalarFunctionNode}
@@ -799,14 +800,25 @@ case class HashAggregateExecPullOutHelper(
override protected def getAttrForAggregateExprs: List[Attribute] = {
aggregateExpressions.zipWithIndex.flatMap {
case (expr, index) =>
expr.mode match {
case Partial | PartialMerge =>
expr.aggregateFunction.aggBufferAttributes
case Final =>
Seq(aggregateAttributes(index))
case other =>
throw new UnsupportedOperationException(s"Unsupported aggregate mode: $other.")
}
handleSpecialAggregateAttr
.lift(expr)
.getOrElse(expr.mode match {
case Partial | PartialMerge =>
expr.aggregateFunction.aggBufferAttributes
case Final =>
Seq(aggregateAttributes(index))
case other =>
throw new UnsupportedOperationException(s"Unsupported aggregate mode: $other.")
})
}.toList
}

private val handleSpecialAggregateAttr: PartialFunction[AggregateExpression, Seq[Attribute]] = {
case ae: AggregateExpression if RewriteTypedImperativeAggregate.shouldRewrite(ae) =>
val aggBufferAttr = ae.aggregateFunction.inputAggBufferAttributes.head
Seq(
aggBufferAttr.copy(dataType = ae.aggregateFunction.dataType)(
aggBufferAttr.exprId,
aggBufferAttr.qualifier))
}
}
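
Note that the copy above swaps only the declared data type while keeping the buffer attribute's exprId and qualifier, so the rest of the plan still resolves to the same attribute. A standalone sketch of that pattern (the attribute name is illustrative):

```scala
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.types.{ArrayType, BinaryType, StringType}

val buf = AttributeReference("buf", BinaryType)()
// Only the first (case-class) parameter list changes; exprId and
// qualifier pass through unchanged in the second list.
val rewritten = buf.copy(dataType = ArrayType(StringType))(buf.exprId, buf.qualifier)
assert(rewritten.exprId == buf.exprId)
assert(rewritten.dataType == ArrayType(StringType))
```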
@@ -77,6 +77,9 @@ object VeloxIntermediateData {
aggregateFunc match {
case _ @Type(veloxDataTypes: Seq[DataType]) =>
Seq(StructType(veloxDataTypes.map(StructField("", _)).toArray))
case _: CollectList | _: CollectSet =>
// CollectList and CollectSet should use the data type of the aggregate function.
Seq(aggregateFunc.dataType)
case _ =>
// Do not use StructType for single-column aggregate intermediate data
aggregateFunc.aggBufferAttributes.map(_.dataType)
@@ -686,6 +686,17 @@ abstract class VeloxAggregateFunctionsSuite extends VeloxWholeStageTransformerSu
}) == 4)
}
}
runQueryAndCompare(
"SELECT collect_list(DISTINCT n_name), count(*), collect_list(n_name) FROM nation") {
df =>
{
assert(
getExecutedPlan(df).count(
plan => {
plan.isInstanceOf[HashAggregateExecTransformer]
}) == 4)
}
}
}

test("count(1)") {
@@ -713,6 +724,13 @@ abstract class VeloxAggregateFunctionsSuite extends VeloxWholeStageTransformerSu
|""".stripMargin)(
df => assert(getExecutedPlan(df).count(_.isInstanceOf[HashAggregateExecTransformer]) == 2))
}

test("collect_list null inputs") {
runQueryAndCompare("""
|select collect_list(a) from values (1), (-1), (null) AS tab(a)
|""".stripMargin)(
df => assert(getExecutedPlan(df).count(_.isInstanceOf[HashAggregateExecTransformer]) == 2))
}
}

class VeloxAggregateFunctionsDefaultSuite extends VeloxAggregateFunctionsSuite {
2 changes: 2 additions & 0 deletions cpp/velox/compute/WholeStageResultIterator.cc
@@ -461,6 +461,8 @@ std::unordered_map<std::string, std::string> WholeStageResultIterator::getQueryC
std::to_string(veloxCfg_->get<int32_t>(kAbandonPartialAggregationMinPct, 90));
configs[velox::core::QueryConfig::kAbandonPartialAggregationMinRows] =
std::to_string(veloxCfg_->get<int32_t>(kAbandonPartialAggregationMinRows, 100000));
// Spark's collect_set ignores nulls.
configs[velox::core::QueryConfig::kPrestoArrayAggIgnoreNulls] = std::to_string(true);
}
// Spill configs
if (spillStrategy_ == "none") {
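
The config is needed because Spark's collect_set and collect_list drop null inputs, whereas the Presto-style set_agg/array_agg in Velox retain them unless kPrestoArrayAggIgnoreNulls is set. The expected Spark semantics can be reproduced in any spark-shell (assumes a running SparkSession `spark`):

```scala
// The null row is ignored, so the result is [1] rather than [1, null].
spark.sql("SELECT collect_set(a) FROM VALUES (1), (null), (1) AS t(a)").show()
// +--------------+
// |collect_set(a)|
// +--------------+
// |           [1]|
// +--------------+
```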
8 changes: 7 additions & 1 deletion cpp/velox/substrait/SubstraitParser.cc
@@ -375,7 +375,13 @@ std::unordered_map<std::string, std::string> SubstraitParser::substraitVeloxFunc
{"bit_and_merge", "bitwise_and_agg_merge"},
{"murmur3hash", "hash_with_seed"},
{"modulus", "remainder"},
{"date_format", "format_datetime"}};
{"date_format", "format_datetime"},
{"collect_set", "set_agg"},
{"collect_set_partial", "set_agg_partial"},
{"collect_set_merge", "set_agg_merge"},
{"collect_list", "array_agg"},
{"collect_list_partial", "array_agg_partial"},
{"collect_list_merge", "array_agg_merge"}};

const std::unordered_map<std::string, std::string> SubstraitParser::typeMap_ = {
{"bool", "BOOLEAN"},
1 change: 1 addition & 0 deletions cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc
@@ -1127,6 +1127,7 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::AggregateRel& ag
static const std::unordered_set<std::string> supportedAggFuncs = {
"sum",
"collect_set",
"collect_list",
"count",
"avg",
"min",
@@ -130,4 +130,6 @@ trait BackendSettingsApi {

/** Merge two-phase hash-based aggregates if needed */
def mergeTwoPhasesHashBaseAggregateIfNeed(): Boolean = false

/** Whether TypedImperativeAggregate, e.g. collect_list/collect_set, should be rewritten so its buffer attribute carries the aggregate function's result type. */
def shouldRewriteTypedImperativeAggregate(): Boolean = false
}
@@ -144,11 +144,6 @@ abstract class HashAggregateExecBaseTransformer(
mode: AggregateMode): Boolean = {
aggFunc match {
case s: Sum if s.prettyName.equals("try_sum") => false
case _: CollectList | _: CollectSet =>
mode match {
case Partial | Final | Complete => true
case _ => false
}
case bloom if bloom.getClass.getSimpleName.equals("BloomFilterAggregate") =>
mode match {
case Partial | Final | Complete => true
@@ -569,7 +569,11 @@ object ColumnarOverrideRules {
val GLUTEN_IS_ADAPTIVE_CONTEXT = "gluten.isAdaptiveContext"

def rewriteSparkPlanRule(): Rule[SparkPlan] = {
val rewriteRules = Seq(RewriteMultiChildrenCount, PullOutPreProject, PullOutPostProject)
val rewriteRules = Seq(
RewriteMultiChildrenCount,
RewriteTypedImperativeAggregate,
PullOutPreProject,
PullOutPostProject)
new RewriteSparkPlanRulesManager(rewriteRules)
}
}
@@ -0,0 +1,73 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.glutenproject.extension

import io.glutenproject.backendsapi.BackendsApiManager
import io.glutenproject.utils.PullOutProjectHelper

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.aggregate.BaseAggregateExec

object RewriteTypedImperativeAggregate extends Rule[SparkPlan] with PullOutProjectHelper {
private lazy val shouldRewriteTypedImperativeAggregate =
BackendsApiManager.getSettings.shouldRewriteTypedImperativeAggregate()

def shouldRewrite(ae: AggregateExpression): Boolean = {
ae.aggregateFunction match {
case _: CollectList | _: CollectSet =>
ae.mode match {
case Partial | PartialMerge => true
case _ => false
}
case _ => false
}
}

override def apply(plan: SparkPlan): SparkPlan = {
if (!shouldRewriteTypedImperativeAggregate) {
return plan
}

plan match {
case agg: BaseAggregateExec if agg.aggregateExpressions.exists(shouldRewrite) =>
val exprMap = agg.aggregateExpressions
.filter(shouldRewrite)
.map(ae => ae.aggregateFunction.inputAggBufferAttributes.head -> ae)
.toMap
val newResultExpressions = agg.resultExpressions.map {
case attr: AttributeReference =>
exprMap
.get(attr)
.map {
ae =>
attr.copy(dataType = ae.aggregateFunction.dataType)(
exprId = attr.exprId,
qualifier = attr.qualifier
)
}
.getOrElse(attr)
case other => other
}
copyBaseAggregateExec(agg)(newResultExpressions = newResultExpressions)

case _ => plan
}
}
}
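
The shouldRewrite gate means only Partial and PartialMerge expressions, whose buffers actually cross an exchange, are rewritten; Final and Complete aggregates already expose the result type. A small hypothetical check against the rule's public predicate:

```scala
import io.glutenproject.extension.RewriteTypedImperativeAggregate

import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, CollectList, Complete, Partial}
import org.apache.spark.sql.types.StringType

// Only buffer-producing modes trigger the rewrite.
val fn = CollectList(Literal.default(StringType))
assert(RewriteTypedImperativeAggregate.shouldRewrite(
  AggregateExpression(fn, Partial, isDistinct = false)))
assert(!RewriteTypedImperativeAggregate.shouldRewrite(
  AggregateExpression(fn, Complete, isDistinct = false)))
```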
@@ -220,6 +220,7 @@ class VeloxTestSettings extends BackendTestSettings {
.exclude("from_unixtime")
enableSuite[GlutenDecimalExpressionSuite]
enableSuite[GlutenStringFunctionsSuite]
.exclude("SPARK-31993: concat_ws in agg function with plenty of string/array types columns")
enableSuite[GlutenRegexpExpressionsSuite]
enableSuite[GlutenNullExpressionsSuite]
enableSuite[GlutenPredicateSuite]
@@ -19,7 +19,7 @@ package org.apache.spark.sql
import io.glutenproject.execution.HashAggregateExecBaseTransformer

import org.apache.spark.sql.execution.WholeStageCodegenExec
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SQLTestData.DecimalData
@@ -340,13 +340,10 @@ class GlutenDataFrameAggregateSuite extends DataFrameAggregateSuite with GlutenS
// test case for ObjectHashAggregate and SortAggregate
val objHashAggOrSortAggDF = df.groupBy("x").agg(c, collect_list("y"))
objHashAggOrSortAggDF.collect()
val objHashAggOrSortAggPlan =
stripAQEPlan(objHashAggOrSortAggDF.queryExecution.executedPlan)
if (useObjectHashAgg) {
assert(objHashAggOrSortAggPlan.isInstanceOf[ObjectHashAggregateExec])
} else {
assert(objHashAggOrSortAggPlan.isInstanceOf[SortAggregateExec])
}
assert(stripAQEPlan(objHashAggOrSortAggDF.queryExecution.executedPlan).find {
case _: HashAggregateExecBaseTransformer => true
case _ => false
}.isDefined)
}
}
}
@@ -45,6 +45,7 @@ import org.apache.spark.sql.sources.{GlutenBucketedReadWithoutHiveSupportSuite,

class VeloxTestSettings extends BackendTestSettings {
enableSuite[GlutenStringFunctionsSuite]
.exclude("SPARK-31993: concat_ws in agg function with plenty of string/array types columns")
enableSuite[GlutenBloomFilterAggregateQuerySuite]
enableSuite[GlutenDataSourceV2DataFrameSessionCatalogSuite]
enableSuite[GlutenDataSourceV2DataFrameSuite]
@@ -20,7 +20,7 @@ import io.glutenproject.execution.HashAggregateExecBaseTransformer

import org.apache.spark.sql.GlutenTestConstants.GLUTEN_TEST
import org.apache.spark.sql.execution.WholeStageCodegenExec
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
@@ -262,13 +262,10 @@ class GlutenDataFrameAggregateSuite extends DataFrameAggregateSuite with GlutenS
// test case for ObjectHashAggregate and SortAggregate
val objHashAggOrSortAggDF = df.groupBy("x").agg(c, collect_list("y"))
objHashAggOrSortAggDF.collect()
val objHashAggOrSortAggPlan =
stripAQEPlan(objHashAggOrSortAggDF.queryExecution.executedPlan)
if (useObjectHashAgg) {
assert(objHashAggOrSortAggPlan.isInstanceOf[ObjectHashAggregateExec])
} else {
assert(objHashAggOrSortAggPlan.isInstanceOf[SortAggregateExec])
}
assert(stripAQEPlan(objHashAggOrSortAggDF.queryExecution.executedPlan).find {
case _: HashAggregateExecBaseTransformer => true
case _ => false
}.isDefined)
}
}
}
@@ -73,12 +73,7 @@ class GlutenReplaceHashWithSortAggSuite
|)
|GROUP BY key
""".stripMargin
aggExpr match {
case "FIRST" =>
checkAggs(query, 2, 0, 2, 0)
case _ =>
checkAggs(query, 1, 1, 2, 0)
}
checkAggs(query, 2, 0, 2, 0)
}
}
}
@@ -45,6 +45,7 @@ import org.apache.spark.sql.sources.{GlutenBucketedReadWithoutHiveSupportSuite,

class VeloxTestSettings extends BackendTestSettings {
enableSuite[GlutenStringFunctionsSuite]
.exclude("SPARK-31993: concat_ws in agg function with plenty of string/array types columns")
enableSuite[GlutenBloomFilterAggregateQuerySuite]
enableSuite[GlutenDataSourceV2DataFrameSessionCatalogSuite]
enableSuite[GlutenDataSourceV2DataFrameSuite]
@@ -20,7 +20,7 @@ import io.glutenproject.execution.HashAggregateExecBaseTransformer

import org.apache.spark.sql.GlutenTestConstants.GLUTEN_TEST
import org.apache.spark.sql.execution.WholeStageCodegenExec
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
@@ -262,13 +262,10 @@ class GlutenDataFrameAggregateSuite extends DataFrameAggregateSuite with GlutenS
// test case for ObjectHashAggregate and SortAggregate
val objHashAggOrSortAggDF = df.groupBy("x").agg(c, collect_list("y"))
objHashAggOrSortAggDF.collect()
val objHashAggOrSortAggPlan =
stripAQEPlan(objHashAggOrSortAggDF.queryExecution.executedPlan)
if (useObjectHashAgg) {
assert(objHashAggOrSortAggPlan.isInstanceOf[ObjectHashAggregateExec])
} else {
assert(objHashAggOrSortAggPlan.isInstanceOf[SortAggregateExec])
}
assert(stripAQEPlan(objHashAggOrSortAggDF.queryExecution.executedPlan).find {
case _: HashAggregateExecBaseTransformer => true
case _ => false
}.isDefined)
}
}
}
@@ -73,12 +73,7 @@ class GlutenReplaceHashWithSortAggSuite
|)
|GROUP BY key
""".stripMargin
aggExpr match {
case "FIRST" =>
checkAggs(query, 2, 0, 2, 0)
case _ =>
checkAggs(query, 1, 1, 2, 0)
}
checkAggs(query, 2, 0, 2, 0)
}
}
}
