Skip to content

Commit

Permalink
implement window group limit
Browse files Browse the repository at this point in the history
  • Loading branch information
lgbo-ustc committed Sep 11, 2024
1 parent 1950ac4 commit 064bc5d
Show file tree
Hide file tree
Showing 16 changed files with 714 additions and 272 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleEx
import org.apache.spark.sql.execution.joins.{BuildSideRelation, ClickHouseBuildSideRelation, HashedRelationBroadcastMode}
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.utils.{CHExecUtil, PushDownUtil}
import org.apache.spark.sql.execution.window._
import org.apache.spark.sql.types.{DecimalType, StructType}
import org.apache.spark.sql.vectorized.ColumnarBatch

Expand Down Expand Up @@ -909,4 +910,19 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with Logging {
toScale: Int): DecimalType = {
SparkShimLoader.getSparkShims.genDecimalRoundExpressionOutput(decimalType, toScale)
}

override def genWindowGroupLimitTransformer(
partitionSpec: Seq[Expression],
orderSpec: Seq[SortOrder],
rankLikeFunction: Expression,
limit: Int,
mode: WindowGroupLimitMode,
child: SparkPlan): SparkPlan =
CHWindowGroupLimitExecTransformer(
partitionSpec,
orderSpec,
rankLikeFunction,
limit,
mode,
child)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.gluten.execution

import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.exception.GlutenNotSupportException
import org.apache.gluten.expression._
import org.apache.gluten.expression.{ConverterUtils, ExpressionConverter}
import org.apache.gluten.extension.ValidationResult
import org.apache.gluten.metrics.MetricsUpdater
import org.apache.gluten.substrait.`type`.TypeBuilder
import org.apache.gluten.substrait.SubstraitContext
import org.apache.gluten.substrait.extensions.ExtensionBuilder
import org.apache.gluten.substrait.rel.{RelBuilder, RelNode}

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.window.{Final, Partial, WindowGroupLimitMode}

import com.google.protobuf.StringValue
import io.substrait.proto.SortField

import scala.collection.JavaConverters._

case class CHWindowGroupLimitExecTransformer(
partitionSpec: Seq[Expression],
orderSpec: Seq[SortOrder],
rankLikeFunction: Expression,
limit: Int,
mode: WindowGroupLimitMode,
child: SparkPlan)
extends UnaryTransformSupport {

@transient override lazy val metrics =
BackendsApiManager.getMetricsApiInstance.genWindowTransformerMetrics(sparkContext)

override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
copy(child = newChild)

override def metricsUpdater(): MetricsUpdater =
BackendsApiManager.getMetricsApiInstance.genWindowTransformerMetricsUpdater(metrics)

override def output: Seq[Attribute] = child.output

override def requiredChildDistribution: Seq[Distribution] = mode match {
case Partial => super.requiredChildDistribution
case Final =>
if (partitionSpec.isEmpty) {
AllTuples :: Nil
} else {
ClusteredDistribution(partitionSpec) :: Nil
}
}

override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
if (BackendsApiManager.getSettings.requiredChildOrderingForWindowGroupLimit()) {
Seq(partitionSpec.map(SortOrder(_, Ascending)) ++ orderSpec)
} else {
Seq(Nil)
}
}

override def outputOrdering: Seq[SortOrder] = {
if (requiredChildOrdering.forall(_.isEmpty)) {
// The Velox backend `TopNRowNumber` does not require child ordering, because it
// uses hash table to store partition and use priority queue to track of top limit rows.
// Ideally, the output of `TopNRowNumber` is unordered but it is grouped for partition keys.
// To be safe, here we do not propagate the ordering.
// TODO: Make the framework aware of grouped data distribution
Nil
} else {
child.outputOrdering
}
}

override def outputPartitioning: Partitioning = child.outputPartitioning

def getWindowGroupLimitRel(
context: SubstraitContext,
originalInputAttributes: Seq[Attribute],
operatorId: Long,
input: RelNode,
validation: Boolean): RelNode = {
val args = context.registeredFunction
// Partition By Expressions
val partitionsExpressions = partitionSpec
.map(
ExpressionConverter
.replaceWithExpressionTransformer(_, attributeSeq = child.output)
.doTransform(args))
.asJava

// Sort By Expressions
val sortFieldList =
orderSpec.map {
order =>
val builder = SortField.newBuilder()
val exprNode = ExpressionConverter
.replaceWithExpressionTransformer(order.child, attributeSeq = child.output)
.doTransform(args)
builder.setExpr(exprNode.toProtobuf)
builder.setDirectionValue(SortExecTransformer.transformSortDirection(order))
builder.build()
}.asJava
if (!validation) {
val windowFunction = rankLikeFunction match {
case _: RowNumber => ExpressionNames.ROW_NUMBER
case _: Rank => ExpressionNames.RANK
case _: DenseRank => ExpressionNames.DENSE_RANK
case _ => throw new GlutenNotSupportException(s"Unknow window function $rankLikeFunction")
}
val parametersStr = new StringBuffer("WindowGroupLimitParameters:")
parametersStr
.append("window_function=")
.append(windowFunction)
.append("\n")
val message = StringValue.newBuilder().setValue(parametersStr.toString).build()
val extensionNode = ExtensionBuilder.makeAdvancedExtension(
BackendsApiManager.getTransformerApiInstance.packPBMessage(message),
null)
RelBuilder.makeWindowGroupLimitRel(
input,
partitionsExpressions,
sortFieldList,
limit,
extensionNode,
context,
operatorId)
} else {
// Use a extension node to send the input types through Substrait plan for validation.
val inputTypeNodeList = originalInputAttributes
.map(attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable))
.asJava
val extensionNode = ExtensionBuilder.makeAdvancedExtension(
BackendsApiManager.getTransformerApiInstance.packPBMessage(
TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf))

RelBuilder.makeWindowGroupLimitRel(
input,
partitionsExpressions,
sortFieldList,
limit,
extensionNode,
context,
operatorId)
}
}

override protected def doValidateInternal(): ValidationResult = {
if (!BackendsApiManager.getSettings.supportWindowGroupLimitExec(rankLikeFunction)) {
return ValidationResult
.failed(s"Found unsupported rank like function: $rankLikeFunction")
}
val substraitContext = new SubstraitContext
val operatorId = substraitContext.nextOperatorId(this.nodeName)

val relNode =
getWindowGroupLimitRel(substraitContext, child.output, operatorId, null, validation = true)

doNativeValidation(substraitContext, relNode)
}

override protected def doTransform(context: SubstraitContext): TransformContext = {
val childCtx = child.asInstanceOf[TransformSupport].transform(context)
val operatorId = context.nextOperatorId(this.nodeName)

val currRel =
getWindowGroupLimitRel(context, child.output, operatorId, childCtx.root, validation = false)
assert(currRel != null, "Window Group Limit Rel should be valid")
TransformContext(childCtx.outputAttributes, output, currRel)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ abstract class GlutenClickHouseTPCDSAbstractSuite
})

protected def fallbackSets(isAqe: Boolean): Set[Int] = {
if (isSparkVersionGE("3.5")) Set(44, 67, 70) else Set.empty[Int]
Set.empty[Int]
}
protected def excludedTpcdsQueries: Set[String] = Set(
"q66" // inconsistent results
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1855,7 +1855,7 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr
| ) t1
|) t2 where rank = 1
""".stripMargin
compareResultsAgainstVanillaSpark(sql, true, { _ => }, isSparkVersionLE("3.3"))
compareResultsAgainstVanillaSpark(sql, true, { _ => })
}

test("GLUTEN-1874 not null in both streams") {
Expand All @@ -1873,7 +1873,7 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr
| ) t1
|) t2 where rank = 1
""".stripMargin
compareResultsAgainstVanillaSpark(sql, true, { _ => }, isSparkVersionLE("3.3"))
compareResultsAgainstVanillaSpark(sql, true, { _ => })
}

test("GLUTEN-2095: test cast(string as binary)") {
Expand Down Expand Up @@ -2456,7 +2456,7 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr
| ) t1
|) t2 where rank = 1 order by p_partkey limit 100
|""".stripMargin
runQueryAndCompare(sql, noFallBack = isSparkVersionLE("3.3"))({ _ => })
runQueryAndCompare(sql, noFallBack = true)({ _ => })
}

test("GLUTEN-4190: crush on flattening a const null column") {
Expand Down

This file was deleted.

33 changes: 0 additions & 33 deletions cpp-ch/local-engine/AggregateFunctions/WindowGroupLimitFunctions.h

This file was deleted.

Loading

0 comments on commit 064bc5d

Please sign in to comment.