Support collect_list collect_set
liujiayi771 committed Feb 23, 2024
1 parent 996ff4c commit dbaabba
Showing 16 changed files with 148 additions and 29 deletions.
@@ -441,4 +441,10 @@ object BackendSettings extends BackendSettingsApi {
}

override def supportCartesianProductExec(): Boolean = true

override def shouldRewriteTypedImperativeAggregate(): Boolean = {
// The intermediate types of collect_list and collect_set in the Velox backend are not
// consistent with vanilla Spark's, so we need to rewrite the aggregate to get the correct
// data type.
true
}
}
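For context on the comment above, here is a rough sketch, not part of the commit, of why the intermediate types diverge. In vanilla Spark, CollectList and CollectSet are TypedImperativeAggregate functions, so their aggregation buffer is serialized into a single binary column, while Velox's array_agg/set_agg exchange intermediate state using the function's result type (assumes Spark's catalyst API on the classpath):

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.expressions.aggregate.CollectList
import org.apache.spark.sql.types.StringType

// Vanilla Spark serializes the TypedImperativeAggregate buffer into one
// BinaryType column...
val fn = CollectList(AttributeReference("n_name", StringType)())
fn.aggBufferAttributes.map(_.dataType) // List(BinaryType)
// ...whereas the function's result type, which the Velox backend uses for
// its intermediate state, is an array of the input type.
fn.dataType // ArrayType(StringType, containsNull = false)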
@@ -19,6 +19,7 @@ package io.glutenproject.execution
import io.glutenproject.backendsapi.BackendsApiManager
import io.glutenproject.expression._
import io.glutenproject.expression.ConverterUtils.FunctionConfig
import io.glutenproject.extension.RewriteTypedImperativeAggregate
import io.glutenproject.substrait.`type`.{TypeBuilder, TypeNode}
import io.glutenproject.substrait.{AggregationParams, SubstraitContext}
import io.glutenproject.substrait.expression.{AggregateFunctionNode, ExpressionBuilder, ExpressionNode, ScalarFunctionNode}
@@ -799,14 +800,25 @@ case class HashAggregateExecPullOutHelper(
override protected def getAttrForAggregateExprs: List[Attribute] = {
aggregateExpressions.zipWithIndex.flatMap {
case (expr, index) =>
expr.mode match {
case Partial | PartialMerge =>
expr.aggregateFunction.aggBufferAttributes
case Final =>
Seq(aggregateAttributes(index))
case other =>
throw new UnsupportedOperationException(s"Unsupported aggregate mode: $other.")
}
handleSpecialAggregateAttr
.lift(expr)
.getOrElse(expr.mode match {
case Partial | PartialMerge =>
expr.aggregateFunction.aggBufferAttributes
case Final =>
Seq(aggregateAttributes(index))
case other =>
throw new UnsupportedOperationException(s"Unsupported aggregate mode: $other.")
})
}.toList
}

private val handleSpecialAggregateAttr: PartialFunction[AggregateExpression, Seq[Attribute]] = {
case ae: AggregateExpression if RewriteTypedImperativeAggregate.shouldRewrite(ae) =>
val aggBufferAttr = ae.aggregateFunction.inputAggBufferAttributes.head
Seq(
aggBufferAttr.copy(dataType = ae.aggregateFunction.dataType)(
aggBufferAttr.exprId,
aggBufferAttr.qualifier))
}
}
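A note on the dispatch above: handleSpecialAggregateAttr is a PartialFunction, and .lift(expr) turns it into a total function returning Option, so the collect_list/collect_set special case applies only where shouldRewrite matches, and every other expression falls through to the getOrElse branch. A minimal self-contained sketch of that idiom:

// Sketch of the PartialFunction.lift + getOrElse dispatch used above.
val special: PartialFunction[Int, String] = {
  case n if n < 0 => s"special($n)"
}
def describe(n: Int): String =
  special.lift(n).getOrElse(s"default($n)")

describe(-1) // "special(-1)"
describe(2) // "default(2)"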
@@ -77,6 +77,9 @@ object VeloxIntermediateData {
aggregateFunc match {
case _ @Type(veloxDataTypes: Seq[DataType]) =>
Seq(StructType(veloxDataTypes.map(StructField("", _)).toArray))
case _: CollectList | _: CollectSet =>
// CollectList and CollectSet should use the data type of the aggregate function.
Seq(aggregateFunc.dataType)
case _ =>
// Do not use StructType for single-column agg intermediate data
aggregateFunc.aggBufferAttributes.map(_.dataType)
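To make the three branches of this helper concrete, here is a hedged sketch of the intermediate-type shapes each branch reports; the avg field layout is illustrative, not taken from the diff:

import org.apache.spark.sql.types._

// Multi-column intermediates are wrapped in a single unnamed StructType,
// e.g. roughly avg -> struct(sum, count) (illustrative layout).
val structIntermediate = Seq(StructType(Seq(
  StructField("", DoubleType), StructField("", LongType))))
// collect_list / collect_set now report the function's own result type.
val collectIntermediate = Seq(ArrayType(StringType))
// The single-column fallback reports the buffer attribute's type directly.
val singleIntermediate = Seq(DoubleType)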
@@ -42,6 +42,8 @@ abstract class VeloxAggregateFunctionsSuite extends VeloxWholeStageTransformerSu
.set("spark.unsafe.exceptionOnMemoryLeak", "true")
.set("spark.sql.autoBroadcastJoinThreshold", "-1")
.set("spark.sql.sources.useV1SourceList", "avro")
.set("spark.eventLog.dir", "hdfs://master-1-1:9000/spark-history/c-5a28e91deddfe64b")
.set("spark.eventLog.enabled", "true")
}

test("count") {
@@ -686,6 +688,17 @@ abstract class VeloxAggregateFunctionsSuite extends VeloxWholeStageTransformerSu
}) == 4)
}
}
runQueryAndCompare(
"SELECT collect_list(DISTINCT n_name), count(*), collect_list(n_name) FROM nation") {
df =>
{
assert(
getExecutedPlan(df).count(
plan => {
plan.isInstanceOf[HashAggregateExecTransformer]
}) == 4)
}
}
}

test("count(1)") {
@@ -713,6 +726,13 @@ abstract class VeloxAggregateFunctionsSuite extends VeloxWholeStageTransformerSu
|""".stripMargin)(
df => assert(getExecutedPlan(df).count(_.isInstanceOf[HashAggregateExecTransformer]) == 2))
}

test("collect_list null inputs") {
runQueryAndCompare("""
|select collect_list(a) from values (1), (-1), (null) AS tab(a)
|""".stripMargin)(
df => assert(getExecutedPlan(df).count(_.isInstanceOf[HashAggregateExecTransformer]) == 2))
}
}

class VeloxAggregateFunctionsDefaultSuite extends VeloxAggregateFunctionsSuite {
2 changes: 2 additions & 0 deletions cpp/velox/compute/WholeStageResultIterator.cc
@@ -461,6 +461,8 @@ std::unordered_map<std::string, std::string> WholeStageResultIterator::getQueryC
std::to_string(veloxCfg_->get<int32_t>(kAbandonPartialAggregationMinPct, 90));
configs[velox::core::QueryConfig::kAbandonPartialAggregationMinRows] =
std::to_string(veloxCfg_->get<int32_t>(kAbandonPartialAggregationMinRows, 100000));
// Spark's collect_set ignores nulls.
configs[velox::core::QueryConfig::kPrestoArrayAggIgnoreNulls] = std::to_string(true);
}
// Spill configs
if (spillStrategy_ == "none") {
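The kPrestoArrayAggIgnoreNulls setting above is needed because Presto's array_agg/set_agg, which Velox implements, keep nulls by default, whereas Spark's collect_list/collect_set drop them. A small sketch of the Spark semantics being matched (assumes an active SparkSession named spark):

// Spark's collect_list drops null inputs but keeps duplicates.
spark.sql("SELECT collect_list(a) FROM VALUES (1), (null), (1) AS t(a)").show()
// +---------------+
// |collect_list(a)|
// +---------------+
// |         [1, 1]|
// +---------------+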
8 changes: 7 additions & 1 deletion cpp/velox/substrait/SubstraitParser.cc
@@ -375,7 +375,13 @@ std::unordered_map<std::string, std::string> SubstraitParser::substraitVeloxFunc
{"bit_and_merge", "bitwise_and_agg_merge"},
{"murmur3hash", "hash_with_seed"},
{"modulus", "remainder"},
{"date_format", "format_datetime"}};
{"date_format", "format_datetime"},
{"collect_set", "set_agg"},
{"collect_set_partial", "set_agg_partial"},
{"collect_set_merge", "set_agg_merge"},
{"collect_list", "array_agg"},
{"collect_list_partial", "array_agg_partial"},
{"collect_list_merge", "array_agg_merge"}};

const std::unordered_map<std::string, std::string> SubstraitParser::typeMap_ = {
{"bool", "BOOLEAN"},
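For reference, the map above is keyed by the Spark function name plus a phase suffix, so each phase of a two-phase aggregate resolves to the matching Velox companion function. A hedged Scala sketch of the lookup idea (the real implementation is the C++ map above):

// Sketch: Spark name (+ phase suffix) -> Velox function name.
val substraitVeloxFunctionMap = Map(
  "collect_set" -> "set_agg",
  "collect_set_partial" -> "set_agg_partial",
  "collect_set_merge" -> "set_agg_merge",
  "collect_list" -> "array_agg",
  "collect_list_partial" -> "array_agg_partial",
  "collect_list_merge" -> "array_agg_merge")

def veloxName(name: String): String =
  substraitVeloxFunctionMap.getOrElse(name, name)

veloxName("collect_list_partial") // "array_agg_partial"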
1 change: 1 addition & 0 deletions cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc
@@ -1118,6 +1118,7 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::AggregateRel& ag
static const std::unordered_set<std::string> supportedAggFuncs = {
"sum",
"collect_set",
"collect_list",
"count",
"avg",
"min",
@@ -128,4 +128,6 @@ trait BackendSettingsApi {

/** Merge two-phase hash-based aggregates if needed. */
def mergeTwoPhasesHashBaseAggregateIfNeed(): Boolean = false

def shouldRewriteTypedImperativeAggregate(): Boolean = false
}
@@ -144,11 +144,6 @@ abstract class HashAggregateExecBaseTransformer(
mode: AggregateMode): Boolean = {
aggFunc match {
case s: Sum if s.prettyName.equals("try_sum") => false
case _: CollectList | _: CollectSet =>
mode match {
case Partial | Final | Complete => true
case _ => false
}
case bloom if bloom.getClass.getSimpleName.equals("BloomFilterAggregate") =>
mode match {
case Partial | Final | Complete => true
@@ -559,7 +559,11 @@ object ColumnarOverrideRules {
val GLUTEN_IS_ADAPTIVE_CONTEXT = "gluten.isAdaptiveContext"

def rewriteSparkPlanRule(): Rule[SparkPlan] = {
val rewriteRules = Seq(RewriteMultiChildrenCount, PullOutPreProject, PullOutPostProject)
val rewriteRules = Seq(
RewriteMultiChildrenCount,
RewriteTypedImperativeAggregate,
PullOutPreProject,
PullOutPostProject)
new RewriteSparkPlanRulesManager(rewriteRules)
}
}
@@ -0,0 +1,73 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.glutenproject.extension

import io.glutenproject.backendsapi.BackendsApiManager
import io.glutenproject.utils.PullOutProjectHelper

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.aggregate.BaseAggregateExec

object RewriteTypedImperativeAggregate extends Rule[SparkPlan] with PullOutProjectHelper {
private lazy val shouldRewriteTypedImperativeAggregate =
BackendsApiManager.getSettings.shouldRewriteTypedImperativeAggregate()

def shouldRewrite(ae: AggregateExpression): Boolean = {
ae.aggregateFunction match {
case _: CollectList | _: CollectSet =>
ae.mode match {
case Partial | PartialMerge => true
case _ => false
}
case _ => false
}
}

override def apply(plan: SparkPlan): SparkPlan = {
if (!shouldRewriteTypedImperativeAggregate) {
return plan
}

plan match {
case agg: BaseAggregateExec if agg.aggregateExpressions.exists(shouldRewrite) =>
val exprMap = agg.aggregateExpressions
.filter(shouldRewrite)
.map(ae => ae.aggregateFunction.inputAggBufferAttributes.head -> ae)
.toMap
val newResultExpressions = agg.resultExpressions.map {
case attr: AttributeReference =>
exprMap
.get(attr)
.map {
ae =>
attr.copy(dataType = ae.aggregateFunction.dataType)(
exprId = attr.exprId,
qualifier = attr.qualifier
)
}
.getOrElse(attr)
case other => other
}
copyBaseAggregateExec(agg)(newResultExpressions = newResultExpressions)

case _ => plan
}
}
}
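Concretely, the rule above retypes only the buffer attribute that a Partial/PartialMerge collect_list or collect_set exposes through resultExpressions, keeping the attribute's identity intact. A minimal sketch of the attr.copy trick it relies on (assumes Spark's catalyst API on the classpath):

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.types.{ArrayType, BinaryType, StringType}

// The partial aggregate's buffer attribute starts out as BinaryType...
val buf = AttributeReference("buf", BinaryType)()
// ...and is re-declared with the function's result type while keeping the
// same exprId and qualifier, so downstream references still resolve.
val retyped = buf.copy(dataType = ArrayType(StringType))(buf.exprId, buf.qualifier)
assert(retyped.exprId == buf.exprId)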
@@ -220,6 +220,7 @@ class VeloxTestSettings extends BackendTestSettings {
.exclude("from_unixtime")
enableSuite[GlutenDecimalExpressionSuite]
enableSuite[GlutenStringFunctionsSuite]
.exclude("SPARK-31993: concat_ws in agg function with plenty of string/array types columns")
enableSuite[GlutenRegexpExpressionsSuite]
enableSuite[GlutenNullExpressionsSuite]
enableSuite[GlutenPredicateSuite]
@@ -45,6 +45,7 @@ import org.apache.spark.sql.sources.{GlutenBucketedReadWithoutHiveSupportSuite,

class VeloxTestSettings extends BackendTestSettings {
enableSuite[GlutenStringFunctionsSuite]
.exclude("SPARK-31993: concat_ws in agg function with plenty of string/array types columns")
enableSuite[GlutenBloomFilterAggregateQuerySuite]
enableSuite[GlutenDataSourceV2DataFrameSessionCatalogSuite]
enableSuite[GlutenDataSourceV2DataFrameSuite]
@@ -20,7 +20,7 @@ import io.glutenproject.execution.HashAggregateExecBaseTransformer

import org.apache.spark.sql.GlutenTestConstants.GLUTEN_TEST
import org.apache.spark.sql.execution.WholeStageCodegenExec
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
@@ -262,13 +262,10 @@ class GlutenDataFrameAggregateSuite extends DataFrameAggregateSuite with GlutenS
// test case for ObjectHashAggregate and SortAggregate
val objHashAggOrSortAggDF = df.groupBy("x").agg(c, collect_list("y"))
objHashAggOrSortAggDF.collect()
val objHashAggOrSortAggPlan =
stripAQEPlan(objHashAggOrSortAggDF.queryExecution.executedPlan)
if (useObjectHashAgg) {
assert(objHashAggOrSortAggPlan.isInstanceOf[ObjectHashAggregateExec])
} else {
assert(objHashAggOrSortAggPlan.isInstanceOf[SortAggregateExec])
}
assert(stripAQEPlan(objHashAggOrSortAggDF.queryExecution.executedPlan).find {
case _: HashAggregateExecBaseTransformer => true
case _ => false
}.isDefined)
}
}
}
@@ -73,12 +73,7 @@ class GlutenReplaceHashWithSortAggSuite
|)
|GROUP BY key
""".stripMargin
aggExpr match {
case "FIRST" =>
checkAggs(query, 2, 0, 2, 0)
case _ =>
checkAggs(query, 1, 1, 2, 0)
}
checkAggs(query, 2, 0, 2, 0)
}
}
}
@@ -45,6 +45,7 @@ import org.apache.spark.sql.sources.{GlutenBucketedReadWithoutHiveSupportSuite,

class VeloxTestSettings extends BackendTestSettings {
enableSuite[GlutenStringFunctionsSuite]
.exclude("SPARK-31993: concat_ws in agg function with plenty of string/array types columns")
enableSuite[GlutenBloomFilterAggregateQuerySuite]
enableSuite[GlutenDataSourceV2DataFrameSessionCatalogSuite]
enableSuite[GlutenDataSourceV2DataFrameSuite]