From 4207596a6d241a06b578ad7e1c7531b5cc311493 Mon Sep 17 00:00:00 2001 From: lgbo-ustc Date: Wed, 10 Jul 2024 10:01:39 +0800 Subject: [PATCH] support percent_rank --- .../gluten/backendsapi/clickhouse/CHBackend.scala | 3 ++- .../backendsapi/clickhouse/CHSparkPlanExecApi.scala | 2 +- .../GlutenClickHouseTPCHSaltNullParquetSuite.scala | 11 +++++++++++ cpp-ch/local-engine/Parser/WindowRelParser.cpp | 1 + .../CommonAggregateFunctionParser.cpp | 1 + 5 files changed, 16 insertions(+), 2 deletions(-) diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala index d369b8c1626f..3b8499ac8b1f 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala @@ -237,7 +237,8 @@ object CHBackendSettings extends BackendSettingsApi with Logging { } wExpression.windowFunction match { - case _: RowNumber | _: AggregateExpression | _: Rank | _: DenseRank | _: NTile => + case _: RowNumber | _: AggregateExpression | _: Rank | _: DenseRank | _: PercentRank | + _: NTile => allSupported = allSupported case l: Lag => checkLagOrLead(l.third) diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala index f5feade886b9..7665216ce87e 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala @@ -704,7 +704,7 @@ class CHSparkPlanExecApi extends SparkPlanExecApi { val columnName = s"${aliasExpr.name}_${aliasExpr.exprId.id}" val wExpression = aliasExpr.child.asInstanceOf[WindowExpression] wExpression.windowFunction match { - case wf @ (RowNumber() | Rank(_) | DenseRank(_)) => + case wf @ (RowNumber() | Rank(_) | DenseRank(_) | PercentRank(_)) => val aggWindowFunc = wf.asInstanceOf[AggregateWindowFunction] val frame = aggWindowFunc.frame.asInstanceOf[SpecifiedWindowFrame] val windowFunctionNode = ExpressionBuilder.makeWindowFunction( diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala index b0d3e1bdb866..b7bf818a32ce 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala @@ -978,6 +978,17 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr compareResultsAgainstVanillaSpark(sql, true, { _ => }) } + test("window percent_rank") { + val sql = + """ + |select n_regionkey, n_nationkey, + | percent_rank(n_nationkey) OVER (PARTITION BY n_regionkey ORDER BY n_nationkey) as n_rank + |from nation + |order by n_regionkey, n_nationkey + |""".stripMargin + compareResultsAgainstVanillaSpark(sql, true, { _ => }) + } + test("window ntile") { val sql = """ diff --git a/cpp-ch/local-engine/Parser/WindowRelParser.cpp b/cpp-ch/local-engine/Parser/WindowRelParser.cpp index 4125879b5ec7..2317c8098b85 100644 --- a/cpp-ch/local-engine/Parser/WindowRelParser.cpp +++ b/cpp-ch/local-engine/Parser/WindowRelParser.cpp @@ -172,6 +172,7 @@ WindowRelParser::parseWindowFrameType(const std::string & function_name, const s static const std::unordered_map special_function_frame_type = { {"rank", substrait::RANGE}, {"dense_rank", substrait::RANGE}, + {"percent_rank", substrait::RANGE} }; substrait::WindowType frame_type; diff --git a/cpp-ch/local-engine/Parser/aggregate_function_parser/CommonAggregateFunctionParser.cpp b/cpp-ch/local-engine/Parser/aggregate_function_parser/CommonAggregateFunctionParser.cpp index e7d6e1b9bd73..d88885a312b2 100644 --- a/cpp-ch/local-engine/Parser/aggregate_function_parser/CommonAggregateFunctionParser.cpp +++ b/cpp-ch/local-engine/Parser/aggregate_function_parser/CommonAggregateFunctionParser.cpp @@ -40,6 +40,7 @@ REGISTER_COMMON_AGGREGATE_FUNCTION_PARSER(FirstIgnoreNull, first_ignore_null, fi REGISTER_COMMON_AGGREGATE_FUNCTION_PARSER(Last, last, last_value_respect_nulls) REGISTER_COMMON_AGGREGATE_FUNCTION_PARSER(LastIgnoreNull, last_ignore_null, last_value) REGISTER_COMMON_AGGREGATE_FUNCTION_PARSER(DenseRank, dense_rank, dense_rank) +REGISTER_COMMON_AGGREGATE_FUNCTION_PARSER(PercentRank, percent_rank, percent_rank) REGISTER_COMMON_AGGREGATE_FUNCTION_PARSER(Rank, rank, rank) REGISTER_COMMON_AGGREGATE_FUNCTION_PARSER(RowNumber, row_number, row_number) REGISTER_COMMON_AGGREGATE_FUNCTION_PARSER(CountDistinct, count_distinct, uniqExact)