Skip to content

Commit

Permalink
implement window group limit
Browse files Browse the repository at this point in the history
  • Loading branch information
lgbo-ustc committed Sep 10, 2024
1 parent 27d134b commit 7205c3f
Show file tree
Hide file tree
Showing 17 changed files with 730 additions and 263 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleEx
import org.apache.spark.sql.execution.joins.{BuildSideRelation, ClickHouseBuildSideRelation, HashedRelationBroadcastMode}
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.utils.{CHExecUtil, PushDownUtil}
import org.apache.spark.sql.execution.window._
import org.apache.spark.sql.types.{DecimalType, StructType}
import org.apache.spark.sql.vectorized.ColumnarBatch

Expand Down Expand Up @@ -883,4 +884,19 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with Logging {
toScale: Int): DecimalType = {
SparkShimLoader.getSparkShims.genDecimalRoundExpressionOutput(decimalType, toScale)
}

/**
 * Creates the ClickHouse-backend plan node for a window group limit.
 *
 * Every argument is forwarded unchanged to [[CHWindowGroupLimitExecTransformer]].
 *
 * @param partitionSpec    partition-by expressions of the window
 * @param orderSpec        order-by specification of the window
 * @param rankLikeFunction rank-like window function expression the limit applies to
 * @param limit            limit value handed to the transformer
 * @param mode             window group limit mode handed to the transformer
 * @param child            input plan
 * @return a new [[CHWindowGroupLimitExecTransformer]] wrapping `child`
 */
override def genWindowGroupLimitTransformer(
    partitionSpec: Seq[Expression],
    orderSpec: Seq[SortOrder],
    rankLikeFunction: Expression,
    limit: Int,
    mode: WindowGroupLimitMode,
    child: SparkPlan): SparkPlan = {
  CHWindowGroupLimitExecTransformer(
    partitionSpec = partitionSpec,
    orderSpec = orderSpec,
    rankLikeFunction = rankLikeFunction,
    limit = limit,
    mode = mode,
    child = child)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.gluten.execution

import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.exception.GlutenNotSupportException
import org.apache.gluten.expression._
import org.apache.gluten.expression.{ConverterUtils, ExpressionConverter}
import org.apache.gluten.extension.ValidationResult
import org.apache.gluten.metrics.MetricsUpdater
import org.apache.gluten.substrait.`type`.TypeBuilder
import org.apache.gluten.substrait.SubstraitContext
import org.apache.gluten.substrait.extensions.ExtensionBuilder
import org.apache.gluten.substrait.rel.{RelBuilder, RelNode}

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.window.{Final, Partial, WindowGroupLimitMode}

import com.google.protobuf.StringValue
import io.substrait.proto.SortField

import scala.collection.JavaConverters._

/**
 * Transformer that converts a window group limit operator into a Substrait
 * window-group-limit relation for the native backend.
 *
 * @param partitionSpec    partition-by expressions of the window
 * @param orderSpec        order-by specification of the window
 * @param rankLikeFunction rank-like window function; `RowNumber`, `Rank` and `DenseRank` are
 *                         the ones this class can translate
 * @param limit            limit value forwarded to the native relation
 * @param mode             `Partial` or `Final`; decides the required child distribution
 * @param child            input plan
 */
case class CHWindowGroupLimitExecTransformer(
    partitionSpec: Seq[Expression],
    orderSpec: Seq[SortOrder],
    rankLikeFunction: Expression,
    limit: Int,
    mode: WindowGroupLimitMode,
    child: SparkPlan)
  extends UnaryTransformSupport {

  // Reuses the metrics (and the updater below) that the metrics API generates for the
  // window transformer.
  @transient override lazy val metrics =
    BackendsApiManager.getMetricsApiInstance.genWindowTransformerMetrics(sparkContext)

  override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
    copy(child = newChild)

  override def metricsUpdater(): MetricsUpdater =
    BackendsApiManager.getMetricsApiInstance.genWindowTransformerMetricsUpdater(metrics)

  /** This operator only drops rows, so the schema is the child's schema. */
  override def output: Seq[Attribute] = child.output

  /**
   * `Partial` mode defers to the parent's requirement. `Final` mode requires all tuples in one
   * partition when there is no partition spec, otherwise clustering on the partition spec.
   */
  override def requiredChildDistribution: Seq[Distribution] = mode match {
    case Partial => super.requiredChildDistribution
    case Final =>
      if (partitionSpec.isEmpty) {
        AllTuples :: Nil
      } else {
        ClusteredDistribution(partitionSpec) :: Nil
      }
  }

  /**
   * Requires the child to be sorted by (partition keys ascending, order spec) only when the
   * backend settings ask for it; otherwise no ordering is required.
   */
  override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
    if (BackendsApiManager.getSettings.requiredChildOrderingForWindowGroupLimit()) {
      Seq(partitionSpec.map(SortOrder(_, Ascending)) ++ orderSpec)
    } else {
      Seq(Nil)
    }
  }

  override def outputOrdering: Seq[SortOrder] = {
    if (requiredChildOrdering.forall(_.isEmpty)) {
      // The Velox backend `TopNRowNumber` does not require child ordering, because it uses a
      // hash table to store partitions and a priority queue to keep track of the top `limit`
      // rows. Ideally, the output of `TopNRowNumber` is unordered but grouped by partition keys.
      // To be safe, here we do not propagate the ordering.
      // TODO: Make the framework aware of grouped data distribution
      Nil
    } else {
      child.outputOrdering
    }
  }

  override def outputPartitioning: Partitioning = child.outputPartitioning

  /**
   * Builds the Substrait window-group-limit relation for this operator.
   *
   * @param context                 Substrait context supplying the registered function map
   * @param originalInputAttributes input attributes; only used in validation mode, where their
   *                                types are shipped through the extension node
   * @param operatorId              operator id assigned to the relation
   * @param input                   child relation (`null` is passed during validation)
   * @param validation              when true, the extension node carries the input types instead
   *                                of the window function parameters
   * @return the relation produced by `RelBuilder.makeWindowGroupLimitRel`
   * @throws GlutenNotSupportException
   *   outside validation mode, if `rankLikeFunction` is not `RowNumber`, `Rank` or `DenseRank`
   */
  def getWindowGroupLimitRel(
      context: SubstraitContext,
      originalInputAttributes: Seq[Attribute],
      operatorId: Long,
      input: RelNode,
      validation: Boolean): RelNode = {
    val args = context.registeredFunction

    // Partition By Expressions
    val partitionsExpressions = partitionSpec
      .map(
        ExpressionConverter
          .replaceWithExpressionTransformer(_, attributeSeq = child.output)
          .doTransform(args))
      .asJava

    // Sort By Expressions
    val sortFieldList = orderSpec.map {
      order =>
        val builder = SortField.newBuilder()
        val exprNode = ExpressionConverter
          .replaceWithExpressionTransformer(order.child, attributeSeq = child.output)
          .doTransform(args)
        builder.setExpr(exprNode.toProtobuf)
        builder.setDirectionValue(SortExecTransformer.transformSortDirection(order))
        builder.build()
    }.asJava

    // Only the extension node differs between validation and execution, so it is computed
    // first and the relation is built in a single place.
    val extensionNode =
      if (validation) {
        // Use an extension node to send the input types through Substrait plan for validation.
        val inputTypeNodeList = originalInputAttributes
          .map(attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable))
          .asJava
        ExtensionBuilder.makeAdvancedExtension(
          BackendsApiManager.getTransformerApiInstance.packPBMessage(
            TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf))
      } else {
        val windowFunction = rankLikeFunction match {
          case _: RowNumber => ExpressionNames.ROW_NUMBER
          case _: Rank => ExpressionNames.RANK
          case _: DenseRank => ExpressionNames.DENSE_RANK
          case _ =>
            throw new GlutenNotSupportException(s"Unknown window function $rankLikeFunction")
        }
        // The window function name is sent as a "key=value\n" line behind a fixed prefix.
        val parameters = s"WindowGroupLimitParameters:window_function=$windowFunction\n"
        val message = StringValue.newBuilder().setValue(parameters).build()
        ExtensionBuilder.makeAdvancedExtension(
          BackendsApiManager.getTransformerApiInstance.packPBMessage(message),
          null)
      }

    RelBuilder.makeWindowGroupLimitRel(
      input,
      partitionsExpressions,
      sortFieldList,
      limit,
      extensionNode,
      context,
      operatorId)
  }

  /**
   * Fails fast when the backend settings reject `rankLikeFunction`; otherwise builds a
   * validation-mode relation (with no input relation) and runs native validation on it.
   */
  override protected def doValidateInternal(): ValidationResult = {
    if (!BackendsApiManager.getSettings.supportWindowGroupLimitExec(rankLikeFunction)) {
      ValidationResult.failed(s"Found unsupported rank like function: $rankLikeFunction")
    } else {
      val substraitContext = new SubstraitContext
      val operatorId = substraitContext.nextOperatorId(this.nodeName)

      val relNode =
        getWindowGroupLimitRel(substraitContext, child.output, operatorId, null, validation = true)

      doNativeValidation(substraitContext, relNode)
    }
  }

  /** Transforms the child first, then stacks the window-group-limit relation on its root. */
  override protected def doTransform(context: SubstraitContext): TransformContext = {
    val childCtx = child.asInstanceOf[TransformSupport].transform(context)
    val operatorId = context.nextOperatorId(this.nodeName)

    val currRel =
      getWindowGroupLimitRel(context, child.output, operatorId, childCtx.root, validation = false)
    assert(currRel != null, "Window Group Limit Rel should be valid")
    TransformContext(childCtx.outputAttributes, output, currRel)
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ abstract class GlutenClickHouseTPCDSAbstractSuite
})

/**
 * Returns the fallback set, which is empty for every value of `isAqe`.
 *
 * The result does not depend on the Spark version: the only value this method ever
 * returned was the trailing `Set.empty[Int]`.
 */
protected def fallbackSets(isAqe: Boolean): Set[Int] = {
  Set.empty[Int]
}
protected def excludedTpcdsQueries: Set[String] = Set(
"q66" // inconsistent results
Expand Down

This file was deleted.

33 changes: 0 additions & 33 deletions cpp-ch/local-engine/AggregateFunctions/WindowGroupLimitFunctions.h

This file was deleted.

13 changes: 11 additions & 2 deletions cpp-ch/local-engine/Common/CHUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
#include <Parser/SubstraitParserUtils.h>
#include <Planner/PlannerActionsVisitor.h>
#include <Processors/Chunk.h>
#include <Processors/Formats/IOutputFormat.h>
#include <Processors/QueryPlan/ExpressionStep.h>
#include <Processors/QueryPlan/QueryPlan.h>
#include <QueryPipeline/QueryPipelineBuilder.h>
Expand Down Expand Up @@ -315,6 +316,16 @@ DB::Block BlockUtil::concatenateBlocksMemoryEfficiently(std::vector<DB::Block> &
return out;
}

/// Formats `block` with the "PrettyCompact" output format of the global query context
/// and returns the formatted text as a string.
String BlockUtil::dumpBlock(const DB::Block & block)
{
    DB::WriteBufferFromOwnString text_buffer;

    // The formatter writes into `text_buffer`; `block` is passed as its header.
    auto formatter = QueryContext::globalContext()->getOutputFormat("PrettyCompact", text_buffer, block);

    // The block is passed through materializeBlock() before being written.
    formatter->write(DB::materializeBlock(block));
    formatter->flush();

    // Finalize the buffer before reading its contents back.
    text_buffer.finalize();
    return text_buffer.str();
}


size_t PODArrayUtil::adjustMemoryEfficientSize(size_t n)
{
Expand Down Expand Up @@ -890,7 +901,6 @@ extern void registerAggregateFunctionCombinatorPartialMerge(AggregateFunctionCom
extern void registerAggregateFunctionsBloomFilter(AggregateFunctionFactory &);
extern void registerAggregateFunctionSparkAvg(AggregateFunctionFactory &);
extern void registerFunctions(FunctionFactory &);
extern void registerWindowGroupLimitFunctions(AggregateFunctionFactory &);

void registerAllFunctions()
{
Expand All @@ -900,7 +910,6 @@ void registerAllFunctions()
auto & agg_factory = AggregateFunctionFactory::instance();
registerAggregateFunctionsBloomFilter(agg_factory);
registerAggregateFunctionSparkAvg(agg_factory);
registerWindowGroupLimitFunctions(agg_factory);
{
/// register aggregate function combinators from local_engine
auto & factory = AggregateFunctionCombinatorFactory::instance();
Expand Down
Loading

0 comments on commit 7205c3f

Please sign in to comment.