diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala index c79d0aaee800..1587b9ea3488 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala @@ -26,7 +26,7 @@ import org.apache.gluten.substrait.rel.LocalFilesNode.ReadFileFormat._ import org.apache.spark.SparkEnv import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.catalog.BucketSpec -import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, DenseRank, Lag, Lead, NamedExpression, Rank, RowNumber} +import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, DenseRank, Expression, Lag, Lead, Literal, NamedExpression, Rank, RowNumber} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning} import org.apache.spark.sql.execution.SparkPlan @@ -225,10 +225,25 @@ object CHBackendSettings extends BackendSettingsApi with Logging { func => { val aliasExpr = func.asInstanceOf[Alias] val wExpression = WindowFunctionsBuilder.extractWindowExpression(aliasExpr.child) + + def checkLagOrLead(third: Expression): Unit = { + third match { + case _: Literal => + allSupported = allSupported + case _ => + logInfo("Not support lag/lead function with default value not literal null") + allSupported = false + break + } + } + wExpression.windowFunction match { - case _: RowNumber | _: AggregateExpression | _: Rank | _: Lead | _: Lag | - _: DenseRank => + case _: RowNumber | _: AggregateExpression | _: Rank | _: DenseRank => allSupported = allSupported + case l: Lag => + checkLagOrLead(l.third) + case l: Lead => + checkLagOrLead(l.third) case _ => logDebug(s"Not support window function: ${wExpression.getClass}") allSupported = false diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala index ee495457edee..c0c6c175a879 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala @@ -944,7 +944,15 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr |from nation |order by n_regionkey, n_nationkey, n_lag |""".stripMargin + val sql1 = + """ + | select n_regionkey, n_nationkey, + | lag(n_nationkey, 1, n_nationkey) OVER (PARTITION BY n_regionkey ORDER BY n_nationkey) as n_lag + |from nation + |order by n_regionkey, n_nationkey, n_lag + |""".stripMargin compareResultsAgainstVanillaSpark(sql, true, { _ => }) + compareResultsAgainstVanillaSpark(sql1, true, { _ => }, false) } test("window lag with null value") {