From 47ab88bafe95bac011e8e15d526618fd1d99d1b8 Mon Sep 17 00:00:00 2001 From: lgbo-ustc Date: Mon, 9 Sep 2024 15:55:24 +0800 Subject: [PATCH] implement window group limit --- .../clickhouse/CHSparkPlanExecApi.scala | 16 + .../CHWindowGroupLimitExecTransformer.scala | 187 +++++++++ .../GlutenClickHouseTPCDSAbstractSuite.scala | 2 +- ...enClickHouseTPCHSaltNullParquetSuite.scala | 6 +- .../WindowGroupLimitFunctions.cpp | 92 ----- .../WindowGroupLimitFunctions.h | 33 -- cpp-ch/local-engine/Common/CHUtil.cpp | 4 +- .../Operator/ReplicateRowsStep.cpp | 20 +- .../Operator/WindowGroupLimitStep.cpp | 365 ++++++++++++++++++ .../Operator/WindowGroupLimitStep.h | 51 +++ .../Parser/AdvancedParametersParseUtil.cpp | 31 +- .../Parser/AdvancedParametersParseUtil.h | 9 +- .../Parser/WindowGroupLimitRelParser.cpp | 149 +++---- .../Parser/WindowGroupLimitRelParser.h | 9 +- .../gluten/backendsapi/SparkPlanExecApi.scala | 10 + .../columnar/OffloadSingleNode.scala | 2 +- 16 files changed, 714 insertions(+), 272 deletions(-) create mode 100644 backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHWindowGroupLimitExecTransformer.scala delete mode 100644 cpp-ch/local-engine/AggregateFunctions/WindowGroupLimitFunctions.cpp delete mode 100644 cpp-ch/local-engine/AggregateFunctions/WindowGroupLimitFunctions.h create mode 100644 cpp-ch/local-engine/Operator/WindowGroupLimitStep.cpp create mode 100644 cpp-ch/local-engine/Operator/WindowGroupLimitStep.h diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala index a8996c4d2e834..b7200cba8ee9a 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala @@ -53,6 +53,7 @@ import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleEx import org.apache.spark.sql.execution.joins.{BuildSideRelation, ClickHouseBuildSideRelation, HashedRelationBroadcastMode} import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.utils.{CHExecUtil, PushDownUtil} +import org.apache.spark.sql.execution.window._ import org.apache.spark.sql.types.{DecimalType, StructType} import org.apache.spark.sql.vectorized.ColumnarBatch @@ -883,4 +884,19 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with Logging { toScale: Int): DecimalType = { SparkShimLoader.getSparkShims.genDecimalRoundExpressionOutput(decimalType, toScale) } + + override def genWindowGroupLimitTransformer( + partitionSpec: Seq[Expression], + orderSpec: Seq[SortOrder], + rankLikeFunction: Expression, + limit: Int, + mode: WindowGroupLimitMode, + child: SparkPlan): SparkPlan = + CHWindowGroupLimitExecTransformer( + partitionSpec, + orderSpec, + rankLikeFunction, + limit, + mode, + child) } diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHWindowGroupLimitExecTransformer.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHWindowGroupLimitExecTransformer.scala new file mode 100644 index 0000000000000..c2648f29ec4cb --- /dev/null +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHWindowGroupLimitExecTransformer.scala @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.execution
+
+import org.apache.gluten.backendsapi.BackendsApiManager
+import org.apache.gluten.exception.GlutenNotSupportException
+import org.apache.gluten.expression._
+import org.apache.gluten.expression.{ConverterUtils, ExpressionConverter}
+import org.apache.gluten.extension.ValidationResult
+import org.apache.gluten.metrics.MetricsUpdater
+import org.apache.gluten.substrait.`type`.TypeBuilder
+import org.apache.gluten.substrait.SubstraitContext
+import org.apache.gluten.substrait.extensions.ExtensionBuilder
+import org.apache.gluten.substrait.rel.{RelBuilder, RelNode}
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning}
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.window.{Final, Partial, WindowGroupLimitMode}
+
+import com.google.protobuf.StringValue
+import io.substrait.proto.SortField
+
+import scala.collection.JavaConverters._
+
+case class CHWindowGroupLimitExecTransformer(
+    partitionSpec: Seq[Expression],
+    orderSpec: Seq[SortOrder],
+    rankLikeFunction: Expression,
+    limit: Int,
+    mode: WindowGroupLimitMode,
+    child: SparkPlan)
+  extends UnaryTransformSupport {
+
+  @transient override lazy val metrics =
+    BackendsApiManager.getMetricsApiInstance.genWindowTransformerMetrics(sparkContext)
+
+  override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
+    copy(child = newChild)
+
+  override def metricsUpdater(): MetricsUpdater =
+    BackendsApiManager.getMetricsApiInstance.genWindowTransformerMetricsUpdater(metrics)
+
+  override def output: Seq[Attribute] = child.output
+
+  override def requiredChildDistribution: Seq[Distribution] = mode match {
+    case Partial => super.requiredChildDistribution
+    case Final =>
+      if (partitionSpec.isEmpty) {
+        AllTuples :: Nil
+      } else {
+        ClusteredDistribution(partitionSpec) :: Nil
+      }
+  }
+
+  override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
+    if (BackendsApiManager.getSettings.requiredChildOrderingForWindowGroupLimit()) {
+      Seq(partitionSpec.map(SortOrder(_, Ascending)) ++ orderSpec)
+    } else {
+      Seq(Nil)
+    }
+  }
+
+  override def outputOrdering: Seq[SortOrder] = {
+    if (requiredChildOrdering.forall(_.isEmpty)) {
+      // The Velox backend `TopNRowNumber` does not require child ordering, because it
+      // uses a hash table to store partitions and a priority queue to track the top `limit` rows.
+      // Ideally, the output of `TopNRowNumber` is unordered, but it is grouped by partition keys.
+      // To be safe, here we do not propagate the ordering.
+      // TODO: Make the framework aware of grouped data distribution
+      Nil
+    } else {
+      child.outputOrdering
+    }
+  }
+
+  override def outputPartitioning: Partitioning = child.outputPartitioning
+
+  def getWindowGroupLimitRel(
+      context: SubstraitContext,
+      originalInputAttributes: Seq[Attribute],
+      operatorId: Long,
+      input: RelNode,
+      validation: Boolean): RelNode = {
+    val args = context.registeredFunction
+    // Partition By Expressions
+    val partitionsExpressions = partitionSpec
+      .map(
+        ExpressionConverter
+          .replaceWithExpressionTransformer(_, attributeSeq = child.output)
+          .doTransform(args))
+      .asJava
+
+    // Sort By Expressions
+    val sortFieldList =
+      orderSpec.map {
+        order =>
+          val builder = SortField.newBuilder()
+          val exprNode = ExpressionConverter
+            .replaceWithExpressionTransformer(order.child, attributeSeq = child.output)
+            .doTransform(args)
+          builder.setExpr(exprNode.toProtobuf)
+          builder.setDirectionValue(SortExecTransformer.transformSortDirection(order))
+          builder.build()
+      }.asJava
+    if (!validation) {
+      val windowFunction = rankLikeFunction match {
+        case _: RowNumber => ExpressionNames.ROW_NUMBER
+        case _: Rank => ExpressionNames.RANK
+        case _: DenseRank => ExpressionNames.DENSE_RANK
+        case _ => throw new GlutenNotSupportException(s"Unknown window function $rankLikeFunction")
+      }
+      val parametersStr = new StringBuffer("WindowGroupLimitParameters:")
+      parametersStr
+        .append("window_function=")
+        .append(windowFunction)
+        .append("\n")
+      val message = StringValue.newBuilder().setValue(parametersStr.toString).build()
+      val extensionNode = ExtensionBuilder.makeAdvancedExtension(
+        BackendsApiManager.getTransformerApiInstance.packPBMessage(message),
+        null)
+      RelBuilder.makeWindowGroupLimitRel(
+        input,
+        partitionsExpressions,
+        sortFieldList,
+        limit,
+        extensionNode,
+        context,
+        operatorId)
+    } else {
+      // Use an extension node to send the input types through the Substrait plan for validation.
+ val inputTypeNodeList = originalInputAttributes + .map(attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable)) + .asJava + val extensionNode = ExtensionBuilder.makeAdvancedExtension( + BackendsApiManager.getTransformerApiInstance.packPBMessage( + TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf)) + + RelBuilder.makeWindowGroupLimitRel( + input, + partitionsExpressions, + sortFieldList, + limit, + extensionNode, + context, + operatorId) + } + } + + override protected def doValidateInternal(): ValidationResult = { + if (!BackendsApiManager.getSettings.supportWindowGroupLimitExec(rankLikeFunction)) { + return ValidationResult + .failed(s"Found unsupported rank like function: $rankLikeFunction") + } + val substraitContext = new SubstraitContext + val operatorId = substraitContext.nextOperatorId(this.nodeName) + + val relNode = + getWindowGroupLimitRel(substraitContext, child.output, operatorId, null, validation = true) + + doNativeValidation(substraitContext, relNode) + } + + override protected def doTransform(context: SubstraitContext): TransformContext = { + val childCtx = child.asInstanceOf[TransformSupport].transform(context) + val operatorId = context.nextOperatorId(this.nodeName) + + val currRel = + getWindowGroupLimitRel(context, child.output, operatorId, childCtx.root, validation = false) + assert(currRel != null, "Window Group Limit Rel should be valid") + TransformContext(childCtx.outputAttributes, output, currRel) + } +} diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSAbstractSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSAbstractSuite.scala index 03b26fa985ea9..abb7d27ffe92e 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSAbstractSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSAbstractSuite.scala @@ -62,7 +62,7 @@ abstract class GlutenClickHouseTPCDSAbstractSuite }) protected def fallbackSets(isAqe: Boolean): Set[Int] = { - if (isSparkVersionGE("3.5")) Set(44, 67, 70) else Set.empty[Int] + Set.empty[Int] } protected def excludedTpcdsQueries: Set[String] = Set( "q66" // inconsistent results diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala index 49697872e8aec..e7ce6eeda1aa4 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala @@ -1855,7 +1855,7 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr | ) t1 |) t2 where rank = 1 """.stripMargin - compareResultsAgainstVanillaSpark(sql, true, { _ => }, isSparkVersionLE("3.3")) + compareResultsAgainstVanillaSpark(sql, true, { _ => }) } test("GLUTEN-1874 not null in both streams") { @@ -1873,7 +1873,7 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr | ) t1 |) t2 where rank = 1 """.stripMargin - compareResultsAgainstVanillaSpark(sql, true, { _ => }, isSparkVersionLE("3.3")) + compareResultsAgainstVanillaSpark(sql, true, { _ => }) } test("GLUTEN-2095: test cast(string as binary)") { @@ -2456,7 +2456,7 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends 
GlutenClickHouseTPCHAbstr | ) t1 |) t2 where rank = 1 order by p_partkey limit 100 |""".stripMargin - runQueryAndCompare(sql, noFallBack = isSparkVersionLE("3.3"))({ _ => }) + runQueryAndCompare(sql, noFallBack = true)({ _ => }) } test("GLUTEN-4190: crush on flattening a const null column") { diff --git a/cpp-ch/local-engine/AggregateFunctions/WindowGroupLimitFunctions.cpp b/cpp-ch/local-engine/AggregateFunctions/WindowGroupLimitFunctions.cpp deleted file mode 100644 index 57232b7ecf59f..0000000000000 --- a/cpp-ch/local-engine/AggregateFunctions/WindowGroupLimitFunctions.cpp +++ /dev/null @@ -1,92 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include -#include -#include -#include -#include -#include -#include -#include - -namespace DB::ErrorCodes -{ -extern const int BAD_ARGUMENTS; -} - -namespace local_engine -{ -WindowFunctionTopRowNumber::WindowFunctionTopRowNumber(const String name, const DB::DataTypes & arg_types, const DB::Array & parameters_) - : DB::WindowFunction(name, arg_types, parameters_, std::make_shared()) -{ - if (parameters.size() != 1) - throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "{} needs a limit parameter", name); - limit = parameters[0].safeGet(); - LOG_ERROR(getLogger("WindowFunctionTopRowNumber"), "xxx {} limit: {}", name, limit); -} - -void WindowFunctionTopRowNumber::windowInsertResultInto(const DB::WindowTransform * transform, size_t function_index) const -{ - LOG_ERROR( - getLogger("WindowFunctionTopRowNumber"), - "xxx current row number: {}, current_row: {}@{}, partition_ended: {}", - transform->current_row_number, - transform->current_row.block, - transform->current_row.row, - transform->partition_ended); - /// If the rank value is larger then limit, and current block only contains rows which are all belong to one partition. - /// We cant drop this block directly. - if (!transform->partition_ended && !transform->current_row.row && transform->current_row_number > limit) - { - /// It's safe to make it mutable here. but it's still too dangerous, it may be changed in the future and make it unsafe. 
- auto * mutable_transform = const_cast(transform); - DB::WindowTransformBlock & current_block = mutable_transform->blockAt(mutable_transform->current_row); - current_block.rows = 0; - auto clear_columns = [](DB::Columns & cols) - { - DB::Columns new_cols; - for (const auto & col : cols) - { - new_cols.push_back(std::move(col->cloneEmpty())); - } - cols = new_cols; - }; - clear_columns(current_block.original_input_columns); - clear_columns(current_block.input_columns); - clear_columns(current_block.casted_columns); - mutable_transform->current_row.block += 1; - throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "{} is not implemented", name); - } - else - { - auto & to_col = *transform->blockAt(transform->current_row).output_columns[function_index]; - assert_cast(to_col).getData().push_back(transform->current_row_number); - } -} - -void registerWindowGroupLimitFunctions(DB::AggregateFunctionFactory & factory) -{ - const DB::AggregateFunctionProperties properties - = {.returns_default_when_only_null = true, .is_order_dependent = true, .is_window_function = true}; - factory.registerFunction( - "top_row_number", - {[](const String & name, const DB::DataTypes & args_type, const DB::Array & parameters, const DB::Settings *) - { return std::make_shared(name, args_type, parameters); }, - properties}, - DB::AggregateFunctionFactory::Case::Insensitive); -} -} diff --git a/cpp-ch/local-engine/AggregateFunctions/WindowGroupLimitFunctions.h b/cpp-ch/local-engine/AggregateFunctions/WindowGroupLimitFunctions.h deleted file mode 100644 index 6c5cc19458d30..0000000000000 --- a/cpp-ch/local-engine/AggregateFunctions/WindowGroupLimitFunctions.h +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include - -namespace local_engine -{ -class WindowFunctionTopRowNumber : public DB::WindowFunction -{ -public: - explicit WindowFunctionTopRowNumber(const String name, const DB::DataTypes & arg_types_, const DB::Array & parameters_); - ~WindowFunctionTopRowNumber() override = default; - - void windowInsertResultInto(const DB::WindowTransform * transform, size_t function_index) const override; - bool allocatesMemoryInArena() const override { return false; } - -private: - size_t limit = 0; -}; -} diff --git a/cpp-ch/local-engine/Common/CHUtil.cpp b/cpp-ch/local-engine/Common/CHUtil.cpp index dd284eac4032a..cbfa3e1900410 100644 --- a/cpp-ch/local-engine/Common/CHUtil.cpp +++ b/cpp-ch/local-engine/Common/CHUtil.cpp @@ -53,6 +53,7 @@ #include #include #include +#include #include #include #include @@ -315,7 +316,6 @@ DB::Block BlockUtil::concatenateBlocksMemoryEfficiently(std::vector & return out; } - size_t PODArrayUtil::adjustMemoryEfficientSize(size_t n) { /// According to definition of DEFUALT_BLOCK_SIZE @@ -890,7 +890,6 @@ extern void registerAggregateFunctionCombinatorPartialMerge(AggregateFunctionCom extern void registerAggregateFunctionsBloomFilter(AggregateFunctionFactory &); extern void registerAggregateFunctionSparkAvg(AggregateFunctionFactory &); extern void registerFunctions(FunctionFactory &); -extern void registerWindowGroupLimitFunctions(AggregateFunctionFactory &); void registerAllFunctions() { @@ -900,7 +899,6 @@ void registerAllFunctions() auto & agg_factory = AggregateFunctionFactory::instance(); registerAggregateFunctionsBloomFilter(agg_factory); registerAggregateFunctionSparkAvg(agg_factory); - registerWindowGroupLimitFunctions(agg_factory); { /// register aggregate function combinators from local_engine auto & factory = AggregateFunctionCombinatorFactory::instance(); diff --git a/cpp-ch/local-engine/Operator/ReplicateRowsStep.cpp b/cpp-ch/local-engine/Operator/ReplicateRowsStep.cpp index f2d4bc8a865dc..ecb027c18f0a0 100644 --- a/cpp-ch/local-engine/Operator/ReplicateRowsStep.cpp +++ b/cpp-ch/local-engine/Operator/ReplicateRowsStep.cpp @@ -32,16 +32,14 @@ namespace local_engine { static DB::ITransformingStep::Traits getTraits() { - return DB::ITransformingStep::Traits - { + return DB::ITransformingStep::Traits{ { .preserves_number_of_streams = true, .preserves_sorting = false, }, { .preserves_number_of_rows = false, - } - }; + }}; } ReplicateRowsStep::ReplicateRowsStep(const DB::DataStream & input_stream) @@ -49,7 +47,7 @@ ReplicateRowsStep::ReplicateRowsStep(const DB::DataStream & input_stream) { } -DB::Block ReplicateRowsStep::transformHeader(const DB::Block& input) +DB::Block ReplicateRowsStep::transformHeader(const DB::Block & input) { DB::Block output; for (int i = 1; i < input.columns(); i++) @@ -59,15 +57,9 @@ DB::Block ReplicateRowsStep::transformHeader(const DB::Block& input) return output; } -void ReplicateRowsStep::transformPipeline( - DB::QueryPipelineBuilder & pipeline, - const DB::BuildQueryPipelineSettings & /*settings*/) +void ReplicateRowsStep::transformPipeline(DB::QueryPipelineBuilder & pipeline, const DB::BuildQueryPipelineSettings & /*settings*/) { - pipeline.addSimpleTransform( - [&](const DB::Block & header) - { - return std::make_shared(header); - }); + pipeline.addSimpleTransform([&](const DB::Block & header) { return std::make_shared(header); }); } void ReplicateRowsStep::updateOutputStream() @@ -105,4 +97,4 @@ void ReplicateRowsTransform::transform(DB::Chunk & chunk) chunk.setColumns(std::move(mutable_columns), total_rows); } -} \ No 
newline at end of file
+}
diff --git a/cpp-ch/local-engine/Operator/WindowGroupLimitStep.cpp b/cpp-ch/local-engine/Operator/WindowGroupLimitStep.cpp
new file mode 100644
index 0000000000000..af04ef5790289
--- /dev/null
+++ b/cpp-ch/local-engine/Operator/WindowGroupLimitStep.cpp
@@ -0,0 +1,365 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "WindowGroupLimitStep.h"
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace DB::ErrorCodes
+{
+extern const int LOGICAL_ERROR;
+}
+
+namespace local_engine
+{
+
+enum class WindowGroupLimitFunction
+{
+    RowNumber,
+    Rank,
+    DenseRank
+};
+
+
+template <WindowGroupLimitFunction function>
+class WindowGroupLimitTransform : public DB::IProcessor
+{
+public:
+    using Status = DB::IProcessor::Status;
+    explicit WindowGroupLimitTransform(
+        const DB::Block & header_, const std::vector<size_t> & partition_columns_, const std::vector<size_t> & sort_columns_, size_t limit_)
+        : DB::IProcessor({header_}, {header_})
+        , header(header_)
+        , partition_columns(partition_columns_)
+        , sort_columns(sort_columns_)
+        , limit(limit_)
+    {
+    }
+    ~WindowGroupLimitTransform() override = default;
+    String getName() const override { return "WindowGroupLimitTransform"; }
+
+    Status prepare() override
+    {
+        auto & output_port = outputs.front();
+        auto & input_port = inputs.front();
+        if (output_port.isFinished())
+        {
+            input_port.close();
+            return Status::Finished;
+        }
+
+        if (has_output)
+        {
+            if (output_port.canPush())
+            {
+                output_port.push(std::move(output_chunk));
+                has_output = false;
+            }
+            return Status::PortFull;
+        }
+
+        if (has_input)
+            return Status::Ready;
+
+        if (input_port.isFinished())
+        {
+            output_port.finish();
+            return Status::Finished;
+        }
+        input_port.setNeeded();
+        if (!input_port.hasData())
+            return Status::NeedData;
+        input_chunk = input_port.pull(true);
+        has_input = true;
+        return Status::Ready;
+    }
+
+    void work() override
+    {
+        if (!has_input) [[unlikely]]
+        {
+            return;
+        }
+        DB::Block block = header.cloneWithColumns(input_chunk.getColumns());
+        size_t partition_start_row = 0;
+        size_t chunk_rows = input_chunk.getNumRows();
+        while (partition_start_row < chunk_rows)
+        {
+            auto next_partition_start_row = advanceNextPartition(input_chunk, partition_start_row);
+            iteratePartition(input_chunk, partition_start_row, next_partition_start_row);
+            partition_start_row = next_partition_start_row;
+            // Corner case: the partition ends exactly at the last row of the chunk.
+            if (partition_start_row < chunk_rows)
+            {
+                current_row_rank_value = 1;
+                if constexpr (function == WindowGroupLimitFunction::Rank)
+                    current_peer_group_rows = 0;
+                partition_start_row_columns = extractOneRowColumns(input_chunk, partition_start_row);
+            }
+        }
+
+        if (!output_columns.empty() && output_columns[0]->size() > 0)
+        {
+            auto rows = output_columns[0]->size();
+            output_chunk = DB::Chunk(std::move(output_columns), rows);
+            output_columns.clear();
+            has_output = true;
+        }
+        has_input = false;
+    }
+
+private:
+    DB::Block header;
+    // Which columns are used as the partition keys.
+    std::vector<size_t> partition_columns;
+    // Which columns are used as the order by keys, excluding partition columns.
+    std::vector<size_t> sort_columns;
+    // Max number of rows to keep for each partition.
+    size_t limit = 0;
+
+    bool has_input = false;
+    DB::Chunk input_chunk;
+    bool has_output = false;
+    DB::MutableColumns output_columns;
+    DB::Chunk output_chunk;
+
+    // We don't have a window frame here; in fact, all frames are (unbounded preceding, current row].
+    // The start value is 1.
+    size_t current_row_rank_value = 1;
+    // `rank` needs this to record how many rows are in the current peer group.
+    // A peer group in a partition is defined as the rows that have the same values on the sort columns.
+    size_t current_peer_group_rows = 0;
+
+    DB::Columns partition_start_row_columns;
+    DB::Columns peer_group_start_row_columns;
+
+
+    size_t advanceNextPartition(const DB::Chunk & chunk, size_t start_offset)
+    {
+        if (partition_start_row_columns.empty())
+            partition_start_row_columns = extractOneRowColumns(chunk, start_offset);
+
+        size_t max_row = chunk.getNumRows();
+        for (size_t i = start_offset; i < max_row; ++i)
+        {
+            if (!isRowEqual(partition_columns, partition_start_row_columns, 0, chunk.getColumns(), i))
+            {
+                return i;
+            }
+        }
+        return max_row;
+    }
+
+    static DB::Columns extractOneRowColumns(const DB::Chunk & chunk, size_t offset)
+    {
+        DB::Columns row;
+        for (const auto & col : chunk.getColumns())
+        {
+            auto new_col = col->cloneEmpty();
+            new_col->insertFrom(*col, offset);
+            row.push_back(std::move(new_col));
+        }
+        return row;
+    }
+
+    static bool isRowEqual(
+        const std::vector<size_t> & fields, const DB::Columns & left_cols, size_t loffset, const DB::Columns & right_cols, size_t roffset)
+    {
+        for (size_t i = 0; i < fields.size(); ++i)
+        {
+            const auto & field = fields[i];
+            /// don't care about nan_direction_hint
+            if (left_cols[field]->compareAt(loffset, roffset, *right_cols[field], 1))
+                return false;
+        }
+        return true;
+    }
+
+    void iteratePartition(const DB::Chunk & chunk, size_t start_offset, size_t end_offset)
+    {
+        // Skip the remaining rows in this partition.
+        if (current_row_rank_value > limit)
+            return;
+
+
+        size_t chunk_rows = chunk.getNumRows();
+        auto has_peer_group_ended = [&](size_t offset, size_t partition_end_offset, size_t chunk_rows_)
+        { return offset < partition_end_offset || end_offset < chunk_rows_; };
+        auto try_end_peer_group
+            = [&](size_t peer_group_start_offset, size_t next_peer_group_start_offset, size_t partition_end_offset, size_t chunk_rows_)
+        {
+            if constexpr (function == WindowGroupLimitFunction::Rank)
+            {
+                current_peer_group_rows += next_peer_group_start_offset - peer_group_start_offset;
+                if (has_peer_group_ended(next_peer_group_start_offset, partition_end_offset, chunk_rows_))
+                {
+                    current_row_rank_value += current_peer_group_rows;
+                    current_peer_group_rows = 0;
+                    peer_group_start_row_columns = extractOneRowColumns(chunk, next_peer_group_start_offset);
+                }
+            }
+            else if constexpr (function == WindowGroupLimitFunction::DenseRank)
+            {
+                if (has_peer_group_ended(next_peer_group_start_offset, partition_end_offset, chunk_rows_))
+                {
+                    current_row_rank_value += 1;
+                    peer_group_start_row_columns = extractOneRowColumns(chunk, next_peer_group_start_offset);
+                }
+            }
+        };
+
+        // This is a corner case: the previous partition's last row is the last row of a chunk.
+        if (start_offset >= end_offset)
+        {
+            assert(!start_offset);
+            try_end_peer_group(start_offset, end_offset, end_offset, chunk_rows);
+            return;
+        }
+
+        // row_number is simple.
+        if constexpr (function == WindowGroupLimitFunction::RowNumber)
+        {
+            size_t rows = end_offset - start_offset;
+            size_t limit_remained = limit - current_row_rank_value + 1;
+            rows = rows > limit_remained ? limit_remained : rows;
+            insertResultValue(chunk, start_offset, rows);
+            current_row_rank_value += rows;
+        }
+        else
+        {
+            size_t peer_group_start_offset = start_offset;
+            while (peer_group_start_offset < end_offset && current_row_rank_value <= limit)
+            {
+                auto next_peer_group_start_offset = advanceNextPeerGroup(chunk, peer_group_start_offset, end_offset);
+
+                insertResultValue(chunk, peer_group_start_offset, next_peer_group_start_offset - peer_group_start_offset);
+                try_end_peer_group(peer_group_start_offset, next_peer_group_start_offset, end_offset, chunk_rows);
+                peer_group_start_offset = next_peer_group_start_offset;
+            }
+        }
+    }
+    void insertResultValue(const DB::Chunk & chunk, size_t start_offset, size_t rows)
+    {
+        if (!rows)
+            return;
+        if (output_columns.empty())
+        {
+            for (const auto & col : chunk.getColumns())
+            {
+                output_columns.push_back(col->cloneEmpty());
+            }
+        }
+        size_t i = 0;
+        for (const auto & col : chunk.getColumns())
+        {
+            output_columns[i]->insertRangeFrom(*col, start_offset, rows);
+            i += 1;
+        }
+    }
+    size_t advanceNextPeerGroup(const DB::Chunk & chunk, size_t start_offset, size_t partition_end_offset)
+    {
+        if (peer_group_start_row_columns.empty())
+            peer_group_start_row_columns = extractOneRowColumns(chunk, start_offset);
+        for (size_t i = start_offset; i < partition_end_offset; ++i)
+        {
+            if (!isRowEqual(sort_columns, peer_group_start_row_columns, 0, chunk.getColumns(), i))
+            {
+                return i;
+            }
+        }
+        return partition_end_offset;
+    }
+};
+
+static DB::ITransformingStep::Traits getTraits()
+{
+    return DB::ITransformingStep::Traits{
+        {
+            .preserves_number_of_streams = false,
+            .preserves_sorting = true,
+        },
+        {
+            .preserves_number_of_rows = false,
+        }};
+}
+
+WindowGroupLimitStep::WindowGroupLimitStep(
+    const DB::DataStream & input_stream_,
+    const String & function_name_,
+    const std::vector<size_t> partition_columns_,
+    const std::vector<size_t> sort_columns_,
+    size_t limit_)
+    : DB::ITransformingStep(input_stream_, input_stream_.header, getTraits())
+    , function_name(function_name_)
+    , partition_columns(partition_columns_)
+    , sort_columns(sort_columns_)
+    , limit(limit_)
+{
+}
+
+void WindowGroupLimitStep::describePipeline(DB::IQueryPlanStep::FormatSettings & settings) const
+{
+    if (!processors.empty())
+        DB::IQueryPlanStep::describePipeline(processors, settings);
+}
+
+void WindowGroupLimitStep::updateOutputStream()
+{
+    output_stream = createOutputStream(input_streams.front(), input_streams.front().header, getDataStreamTraits());
+}
+
+
+void WindowGroupLimitStep::transformPipeline(DB::QueryPipelineBuilder & pipeline, const DB::BuildQueryPipelineSettings & /*settings*/)
+{
+    if (function_name == "row_number")
+    {
+        pipeline.addSimpleTransform(
+            [&](const DB::Block & header)
+            {
+                return std::make_shared<WindowGroupLimitTransform<WindowGroupLimitFunction::RowNumber>>(
+                    header, partition_columns, sort_columns, limit);
+            });
+    }
+    else if (function_name == "rank")
+    {
+        pipeline.addSimpleTransform(
+            [&](const DB::Block & header)
+            {
+                return std::make_shared<WindowGroupLimitTransform<WindowGroupLimitFunction::Rank>>(
+                    header, partition_columns, sort_columns, limit);
+            });
+    }
+    else if (function_name == "dense_rank")
+    {
+        pipeline.addSimpleTransform(
+            [&](const DB::Block & header)
+            {
+                return std::make_shared<WindowGroupLimitTransform<WindowGroupLimitFunction::DenseRank>>(
+                    header, partition_columns, sort_columns, limit);
+            });
+    }
+    else
+    {
+        throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Unsupported function {} in WindowGroupLimit", function_name);
+    }
+}
+}
diff --git a/cpp-ch/local-engine/Operator/WindowGroupLimitStep.h b/cpp-ch/local-engine/Operator/WindowGroupLimitStep.h
new file mode 100644
index 0000000000000..bbbbf42abc55e
--- /dev/null
+++ b/cpp-ch/local-engine/Operator/WindowGroupLimitStep.h
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+#include
+#include
+#include
+#include
+
+namespace local_engine
+{
+class WindowGroupLimitStep : public DB::ITransformingStep
+{
+public:
+    explicit WindowGroupLimitStep(
+        const DB::DataStream & input_stream_,
+        const String & function_name_,
+        const std::vector<size_t> partition_columns_,
+        const std::vector<size_t> sort_columns_,
+        size_t limit_);
+    ~WindowGroupLimitStep() override = default;
+
+    String getName() const override { return "WindowGroupLimitStep"; }
+
+    void transformPipeline(DB::QueryPipelineBuilder & pipeline, const DB::BuildQueryPipelineSettings & settings) override;
+    void describePipeline(DB::IQueryPlanStep::FormatSettings & settings) const override;
+    void updateOutputStream() override;
+
+private:
+    // Window function name: one of row_number, rank and dense_rank.
+    String function_name;
+    std::vector<size_t> partition_columns;
+    std::vector<size_t> sort_columns;
+    size_t limit;
+};
+
+}
diff --git a/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.cpp b/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.cpp
index 42d4f4d4d8cdf..cc7738a15aaa8 100644
--- a/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.cpp
+++ b/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.cpp
@@ -14,25 +14,26 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#include
+#include "AdvancedParametersParseUtil.h"
 #include
 #include
-#include
-#include
+#include
 #include
+#include
 #include
+
 namespace DB::ErrorCodes
 {
-    extern const int BAD_ARGUMENTS;
+extern const int BAD_ARGUMENTS;
 }
 
 namespace local_engine
 {
-template<typename T>
+template <typename T>
 void tryAssign(const std::unordered_map<String, String> & kvs, const String & key, T & v);
 
-template<>
+template <>
 void tryAssign(const std::unordered_map<String, String> & kvs, const String & key, String & v)
 {
     auto it = kvs.find(key);
@@ -40,7 +41,7 @@ void tryAssign(const std::unordered_map<String, String> & kvs, const Str
         v = it->second;
 }
 
-template<>
+template <>
 void tryAssign(const std::unordered_map<String, String> & kvs, const String & key, bool & v)
 {
     auto it = kvs.find(key);
@@ -57,7 +58,7 @@ void tryAssign(const std::unordered_map<String, String> & kvs, const Strin
     }
 }
 
-template<>
+template <>
 void tryAssign(const std::unordered_map<String, String> & kvs, const String & key, Int64 & v)
 {
     auto it = kvs.find(key);
@@ -94,9 +95,9 @@ void readStringUntilCharsInto(String & s, DB::ReadBuffer & buf)
 std::unordered_map<String, std::unordered_map<String, String>> convertToKVs(const String & advance)
 {
     std::unordered_map<String, std::unordered_map<String, String>> res;
-    std::unordered_map<String, String> *kvs;
+    std::unordered_map<String, String> * kvs;
     DB::ReadBufferFromString in(advance);
-    while(!in.eof())
+    while (!in.eof())
     {
         String key;
         readStringUntilCharsInto<'=', '\n', ':'>(key, in);
@@ -146,5 +147,13 @@ JoinOptimizationInfo JoinOptimizationInfo::parse(const String & advance)
     tryAssign(kvs, "numPartitions", info.partitions_num);
     return info;
 }
-}
 
+WindowGroupOptimizationInfo WindowGroupOptimizationInfo::parse(const String & advance)
+{
+    WindowGroupOptimizationInfo info;
+    auto kkvs = convertToKVs(advance);
+    auto & kvs = kkvs["WindowGroupLimitParameters"];
+    tryAssign(kvs, "window_function", info.window_function);
+    return info;
+}
+}
diff --git a/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.h b/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.h
index 5f6fe6d256e3c..fc478db33bfd2 100644
--- a/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.h
+++ b/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.h
@@ -16,10 +16,10 @@
  */
 #pragma once
 #include
+#include
 
 namespace local_engine
 {
-
 std::unordered_map<String, std::unordered_map<String, String>> convertToKVs(const String & advance);
 
@@ -38,5 +38,10 @@ struct JoinOptimizationInfo
     static JoinOptimizationInfo parse(const String & advance);
 };
-}
 
+struct WindowGroupOptimizationInfo
+{
+    String window_function;
+    static WindowGroupOptimizationInfo parse(const String & advance);
+};
+}
diff --git a/cpp-ch/local-engine/Parser/WindowGroupLimitRelParser.cpp b/cpp-ch/local-engine/Parser/WindowGroupLimitRelParser.cpp
index 153918850ff9c..f6c10386f4051 100644
--- a/cpp-ch/local-engine/Parser/WindowGroupLimitRelParser.cpp
+++ b/cpp-ch/local-engine/Parser/WindowGroupLimitRelParser.cpp
@@ -15,127 +15,61 @@
  * limitations under the License.
  */
+#include "WindowGroupLimitRelParser.h"
 #include
+#include
+#include
 #include
 #include
 #include
-#include
 #include
+#include
+#include "AdvancedParametersParseUtil.h"
 
 namespace DB::ErrorCodes
 {
 extern const int BAD_ARGUMENTS;
 }
 
-const static String FUNCTION_ROW_NUM = "top_row_number";
-const static String FUNCTION_RANK = "top_rank";
-const static String FUNCTION_DENSE_RANK = "top_dense_rank";
-
 namespace local_engine
 {
 WindowGroupLimitRelParser::WindowGroupLimitRelParser(SerializedPlanParser * plan_parser_) : RelParser(plan_parser_)
 {
-    LOG_ERROR(getLogger("WindowGroupLimitRelParser"), "xxx new parrser");
 }
+
 DB::QueryPlanPtr
 WindowGroupLimitRelParser::parse(DB::QueryPlanPtr current_plan_, const substrait::Rel & rel, std::list<const substrait::Rel *> & rel_stack_)
 {
     const auto win_rel_def = rel.windowgrouplimit();
-    current_plan = std::move(current_plan_);
-
-    DB::Block output_header = current_plan->getCurrentDataStream().header;
-
-    window_function_name = FUNCTION_ROW_NUM;
-    LOG_ERROR(getLogger("WindowGroupLimitRelParser"), "xxx input header: {}", current_plan->getCurrentDataStream().header.dumpStructure());
+    google::protobuf::StringValue optimize_info_str;
+    optimize_info_str.ParseFromString(win_rel_def.advanced_extension().optimization().value());
+    auto optimization_info = WindowGroupOptimizationInfo::parse(optimize_info_str.value());
+    window_function_name = optimization_info.window_function;
 
-    /// Only one window function in one window group limit
-    auto win_desc = buildWindowDescription(win_rel_def);
+    current_plan = std::move(current_plan_);
 
-    auto win_step = std::make_unique<DB::WindowStep>(current_plan->getCurrentDataStream(), win_desc, win_desc.window_functions, false);
-    win_step->setStepDescription("Window Group Limit " + win_desc.window_name);
-    steps.emplace_back(win_step.get());
-    current_plan->addStep(std::move(win_step));
+    auto partition_fields = parsePartitionFields(win_rel_def.partition_expressions());
+    auto sort_fields = parseSortFields(win_rel_def.sorts());
+    size_t limit = static_cast<size_t>(win_rel_def.limit());
 
-    /// remove the window function result column which is not needed in later steps
-    DB::ActionsDAG post_project_actions_dag = DB::ActionsDAG::makeConvertingActions(
-        current_plan->getCurrentDataStream().header.getColumnsWithTypeAndName(),
-        output_header.getColumnsWithTypeAndName(),
-        DB::ActionsDAG::MatchColumnsMode::Name);
-    auto post_project_step
-        = std::make_unique<DB::ExpressionStep>(current_plan->getCurrentDataStream(), std::move(post_project_actions_dag));
-    post_project_step->setStepDescription("Window group limit: drop window function result column");
-    steps.emplace_back(post_project_step.get());
-    current_plan->addStep(std::move(post_project_step));
+    auto window_group_limit_step = std::make_unique<WindowGroupLimitStep>(
+        current_plan->getCurrentDataStream(), window_function_name, partition_fields, sort_fields, limit);
+    window_group_limit_step->setStepDescription("Window group limit");
+    steps.emplace_back(window_group_limit_step.get());
+    current_plan->addStep(std::move(window_group_limit_step));
 
-    LOG_ERROR(getLogger("WindowGroupLimitRelParser"), "xxx output header: {}", current_plan->getCurrentDataStream().header.dumpStructure());
     return std::move(current_plan);
 }
 
-DB::WindowFrame WindowGroupLimitRelParser::buildWindowFrame(const String & function_name)
+std::vector<size_t>
+WindowGroupLimitRelParser::parsePartitionFields(const google::protobuf::RepeatedPtrField<substrait::Expression> & expressions)
 {
-    // We only need first rows, so let the begin type is unbounded is OK
-    DB::WindowFrame frame;
-    if (function_name == FUNCTION_ROW_NUM)
-    {
-        frame.type = DB::WindowFrame::FrameType::ROWS;
-        frame.begin_type = DB::WindowFrame::BoundaryType::Unbounded;
-        frame.begin_offset = 0;
-        frame.begin_preceding = true;
-        frame.end_type = DB::WindowFrame::BoundaryType::Current;
-        frame.end_offset = 0;
-        frame.end_preceding = true;
-    }
-    else if (function_name == FUNCTION_RANK || function_name == FUNCTION_DENSE_RANK)
-    {
-        // rank and dense_rank can only work on range mode
-        frame.type = DB::WindowFrame::FrameType::RANGE;
-        frame.begin_type = DB::WindowFrame::BoundaryType::Unbounded;
-        frame.begin_offset = 0;
-        frame.begin_preceding = true;
-        frame.end_type = DB::WindowFrame::BoundaryType::Current;
-        frame.end_offset = 0;
-        frame.end_preceding = true;
-    }
-    else
-    {
-        throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Unknown function {} for window group limit", function_name);
-    }
-
-    return frame;
-}
-
-DB::WindowDescription WindowGroupLimitRelParser::buildWindowDescription(const substrait::WindowGroupLimitRel & win_rel_def)
-{
-    DB::WindowDescription win_desc;
-    win_desc.frame = buildWindowFrame(window_function_name);
-    win_desc.partition_by = parsePartitionBy(win_rel_def.partition_expressions());
-    win_desc.order_by = SortRelParser::parseSortDescription(win_rel_def.sorts(), current_plan->getCurrentDataStream().header);
-    win_desc.full_sort_description = win_desc.partition_by;
-    win_desc.full_sort_description.insert(win_desc.full_sort_description.end(), win_desc.order_by.begin(), win_desc.order_by.end());
-
-    DB::WriteBufferFromOwnString ss;
-    ss << "partition by " << DB::dumpSortDescription(win_desc.partition_by);
-    ss << "order by " << DB::dumpSortDescription(win_desc.order_by);
-    ss << win_desc.frame.toString();
-    win_desc.window_name = ss.str();
-
-    win_desc.window_functions.emplace_back(buildWindowFunctionDescription(window_function_name, static_cast<size_t>(win_rel_def.limit())));
-
-    return win_desc;
-}
-
-DB::SortDescription
-WindowGroupLimitRelParser::parsePartitionBy(const google::protobuf::RepeatedPtrField<substrait::Expression> & expressions)
-{
-    DB::Block header = current_plan->getCurrentDataStream().header;
-    DB::SortDescription sort_desc;
+    std::vector<size_t> fields;
     for (const auto & expr : expressions)
     {
         if (expr.has_selection())
         {
-            auto pos = expr.selection().direct_reference().struct_field().field();
-            auto col_name = header.getByPosition(pos).name;
-            sort_desc.push_back(DB::SortColumnDescription(col_name, 1, 1));
+            fields.push_back(static_cast<size_t>(expr.selection().direct_reference().struct_field().field()));
         }
         else if (expr.has_literal())
         {
@@ -143,28 +77,33 @@ WindowGroupLimitRelParser::parsePartitionBy(const google::protobuf::RepeatedPtrF
         }
         else
         {
-            throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Unknow partition argument: {}", expr.DebugString());
+            throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Unknown expression: {}", expr.DebugString());
         }
     }
-    return sort_desc;
+    return fields;
 }
 
-DB::WindowFunctionDescription WindowGroupLimitRelParser::buildWindowFunctionDescription(const String & function_name, size_t limit)
+std::vector<size_t> WindowGroupLimitRelParser::parseSortFields(const google::protobuf::RepeatedPtrField<substrait::SortField> & sort_fields)
 {
-    DB::WindowFunctionDescription desc;
-    desc.column_name = function_name;
-    desc.function_node = nullptr;
-    DB::AggregateFunctionProperties func_properties;
-    DB::Names func_args;
-    DB::DataTypes func_args_types;
-    DB::Array func_params;
-    func_params.push_back(limit);
-    auto func_ptr = RelParser::getAggregateFunction(function_name, func_args_types, func_properties, func_params);
-    desc.argument_names = func_args;
-    desc.argument_types = func_args_types;
-    desc.aggregate_function = func_ptr;
-    return desc;
+    std::vector<size_t> fields;
+    for (const auto & sort_field : sort_fields)
+    {
+        if (sort_field.expr().has_literal())
+        {
+            continue;
+        }
+        else if (sort_field.expr().has_selection())
+        {
+            fields.push_back(static_cast<size_t>(sort_field.expr().selection().direct_reference().struct_field().field()));
+        }
+        else
+        {
+            throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Unknown expression: {}", sort_field.expr().DebugString());
+        }
+    }
+    return fields;
 }
+
 void registerWindowGroupLimitRelParser(RelParserFactory & factory)
 {
     auto builder = [](SerializedPlanParser * plan_parser) { return std::make_shared<WindowGroupLimitRelParser>(plan_parser); };
diff --git a/cpp-ch/local-engine/Parser/WindowGroupLimitRelParser.h b/cpp-ch/local-engine/Parser/WindowGroupLimitRelParser.h
index 6b7d3bbf33e5a..c9c503ed4745f 100644
--- a/cpp-ch/local-engine/Parser/WindowGroupLimitRelParser.h
+++ b/cpp-ch/local-engine/Parser/WindowGroupLimitRelParser.h
@@ -46,12 +46,7 @@ class WindowGroupLimitRelParser : public RelParser
     DB::QueryPlanPtr current_plan;
     String window_function_name;
 
-    DB::WindowDescription buildWindowDescription(const substrait::WindowGroupLimitRel & win_rel_def);
-    /// There is only one type of window frame at present.
-    static DB::WindowFrame buildWindowFrame(const String & function_name);
-
-    DB::SortDescription parsePartitionBy(const google::protobuf::RepeatedPtrField<substrait::Expression> & expressions);
-
-    static DB::WindowFunctionDescription buildWindowFunctionDescription(const String & function_name, size_t limit);
+    std::vector<size_t> parsePartitionFields(const google::protobuf::RepeatedPtrField<substrait::Expression> & expressions);
+    std::vector<size_t> parseSortFields(const google::protobuf::RepeatedPtrField<substrait::SortField> & sort_fields);
 };
 }
diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
index a55926d76d12d..dd4150806cfc9 100644
--- a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
+++ b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
@@ -41,6 +41,7 @@ import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
 import org.apache.spark.sql.execution.joins.BuildSideRelation
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.execution.python.ArrowEvalPythonExec
+import org.apache.spark.sql.execution.window._
 import org.apache.spark.sql.hive.{HiveTableScanExecTransformer, HiveUDFTransformer}
 import org.apache.spark.sql.types.{DecimalType, LongType, NullType, StructType}
 import org.apache.spark.sql.vectorized.ColumnarBatch
@@ -678,6 +679,15 @@ trait SparkPlanExecApi {
     }
   }
 
+  def genWindowGroupLimitTransformer(
+      partitionSpec: Seq[Expression],
+      orderSpec: Seq[SortOrder],
+      rankLikeFunction: Expression,
+      limit: Int,
+      mode: WindowGroupLimitMode,
+      child: SparkPlan): SparkPlan =
+    WindowGroupLimitExecTransformer(partitionSpec, orderSpec, rankLikeFunction, limit, mode, child)
+
   def genHiveUDFTransformer(
       expr: Expression,
       attributeSeq: Seq[Attribute]): ExpressionTransformer = {
diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala
index 6047789e6abe9..5440481f8ddd9 100644
--- a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala
+++ b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala
@@ -358,7 +358,7 @@ object OffloadOthers {
           val windowGroupLimitPlan = SparkShimLoader.getSparkShims
             .getWindowGroupLimitExecShim(plan)
             .asInstanceOf[WindowGroupLimitExecShim]
-          WindowGroupLimitExecTransformer(
+          BackendsApiManager.getSparkPlanExecApiInstance.genWindowGroupLimitTransformer(
            windowGroupLimitPlan.partitionSpec,
            windowGroupLimitPlan.orderSpec,
            windowGroupLimitPlan.rankLikeFunction,
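
For reference, the query shape exercised by the TPCH suite changes above (a rank-like window function
filtered by a constant predicate) is what Spark plans as WindowGroupLimitExec and what this patch
offloads through CHWindowGroupLimitExecTransformer and the native WindowGroupLimitStep. A minimal
sketch in the style of those tests; the test name and the nation table columns are illustrative and
not part of this patch:

  test("window group limit: rank with constant filter") {
    // Spark rewrites the rk <= 3 filter over rank() into a WindowGroupLimitExec,
    // which the ClickHouse backend now offloads instead of falling back.
    val sql =
      """
        |select * from (
        |  select n_regionkey, n_name,
        |    rank() over (partition by n_regionkey order by n_name) as rk
        |  from nation
        |) t where rk <= 3
        |""".stripMargin
    compareResultsAgainstVanillaSpark(sql, true, { _ => })
  }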