From 14a3299fa43809a26400d1feee837a40b4de77a4 Mon Sep 17 00:00:00 2001
From: lgbo
Date: Tue, 9 Jan 2024 09:26:24 +0800
Subject: [PATCH] [GLUTEN-4302][CH] Fixed bugs about rewriting date comparison
 (#4303)

[CH] Fixed bugs about rewriting date comparison
---
 .../clickhouse/CHSparkPlanExecApi.scala       |  9 ++----
 .../RewriteDateTimestampComparisonRule.scala  | 31 +++++++++----------
 .../scala/io/glutenproject/GlutenConfig.scala |  4 ++-
 3 files changed, 20 insertions(+), 24 deletions(-)

diff --git a/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHSparkPlanExecApi.scala b/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHSparkPlanExecApi.scala
index c16ff3a71073..e654a97d69dd 100644
--- a/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHSparkPlanExecApi.scala
+++ b/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHSparkPlanExecApi.scala
@@ -344,12 +344,9 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
    * @return
    */
   override def genExtendedAnalyzers(): List[SparkSession => Rule[LogicalPlan]] = {
-    val analyzers = List(spark => new ClickHouseAnalysis(spark, spark.sessionState.conf))
-    if (GlutenConfig.getConf.enableRewriteDateTimestampComparison) {
-      analyzers :+ (spark => new RewriteDateTimestampComparisonRule(spark, spark.sessionState.conf))
-    } else {
-      analyzers
-    }
+    List(
+      spark => new ClickHouseAnalysis(spark, spark.sessionState.conf),
+      spark => new RewriteDateTimestampComparisonRule(spark, spark.sessionState.conf))
   }
 
   /**
diff --git a/gluten-core/src/main/scala/io/glutenproject/extension/RewriteDateTimestampComparisonRule.scala b/gluten-core/src/main/scala/io/glutenproject/extension/RewriteDateTimestampComparisonRule.scala
index 5bbc9fb1820f..3d3bded41881 100644
--- a/gluten-core/src/main/scala/io/glutenproject/extension/RewriteDateTimestampComparisonRule.scala
+++ b/gluten-core/src/main/scala/io/glutenproject/extension/RewriteDateTimestampComparisonRule.scala
@@ -16,6 +16,8 @@
  */
 package org.apache.spark.sql.extension
 
+import io.glutenproject.GlutenConfig
+
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.expressions._
@@ -46,7 +48,11 @@ class RewriteDateTimestampComparisonRule(session: SparkSession, conf: SQLConf)
   }
 
   override def apply(plan: LogicalPlan): LogicalPlan = {
-    if (plan.resolved) {
+    if (
+      plan.resolved &&
+      GlutenConfig.getConf.enableGluten &&
+      GlutenConfig.getConf.enableRewriteDateTimestampComparison
+    ) {
       visitPlan(plan)
     } else {
       plan
@@ -193,6 +199,7 @@ class RewriteDateTimestampComparisonRule(session: SparkSession, conf: SQLConf)
     Add(toUnixTimestampExpr, adjustExpr)
   }
 
+  // rewrite an expressiont that converts unix timestamp to date back to unix timestamp
   private def rewriteUnixTimestampToDate(expr: Expression): Expression = {
     expr match {
       case toDate: ParseToDate =>
@@ -302,25 +309,15 @@ class RewriteDateTimestampComparisonRule(session: SparkSession, conf: SQLConf)
       return cmp
     }
     val zoneId = getTimeZoneId(cmp.left)
-    val timestampLeft = rewriteUnixTimestampToDate(cmp.left)
-    val adjustedOffset = Literal(TimeUnitToSeconds(timeUnit.get), timestampLeft.dataType)
-    val addjustedOffsetExpr = Remainder(timestampLeft, adjustedOffset)
-    val newLeft = Subtract(timestampLeft, addjustedOffsetExpr)
+    val newLeft = rewriteUnixTimestampToDate(cmp.left)
+    val adjustedOffset = Literal(TimeUnitToSeconds(timeUnit.get), newLeft.dataType)
     val newRight = rewriteConstDate(cmp.right, timeUnit.get, zoneId, 0)
-    EqualTo(newLeft, newRight)
+    val leftBound = GreaterThanOrEqual(newLeft, newRight)
+    val rigtBound = LessThan(newLeft, Add(newRight, adjustedOffset))
+    And(leftBound, rigtBound)
   }
 
   private def rewriteEqualNullSafe(cmp: EqualNullSafe): Expression = {
-    val timeUnit = getDateTimeUnit(cmp.left)
-    if (timeUnit.isEmpty) {
-      return cmp
-    }
-    val zoneId = getTimeZoneId(cmp.left)
-    val timestampLeft = rewriteUnixTimestampToDate(cmp.left)
-    val adjustedOffset = Literal(TimeUnitToSeconds(timeUnit.get), timestampLeft.dataType)
-    val addjustedOffsetExpr = Remainder(timestampLeft, adjustedOffset)
-    val newLeft = Subtract(timestampLeft, addjustedOffsetExpr)
-    val newRight = rewriteConstDate(cmp.right, timeUnit.get, zoneId, 0)
-    EqualNullSafe(newLeft, newRight)
+    cmp
   }
 }
diff --git a/shims/common/src/main/scala/io/glutenproject/GlutenConfig.scala b/shims/common/src/main/scala/io/glutenproject/GlutenConfig.scala
index 9745b02b070a..846b3ab00f89 100644
--- a/shims/common/src/main/scala/io/glutenproject/GlutenConfig.scala
+++ b/shims/common/src/main/scala/io/glutenproject/GlutenConfig.scala
@@ -39,6 +39,8 @@ class GlutenConfig(conf: SQLConf) extends Logging {
 
   def enableAnsiMode: Boolean = conf.ansiEnabled
 
+  def enableGluten: Boolean = conf.getConf(GLUTEN_ENABLED)
+
   // FIXME the option currently controls both JVM and native validation against a Substrait plan.
   def enableNativeValidation: Boolean = conf.getConf(NATIVE_VALIDATION_ENABLED)
 
@@ -1411,7 +1413,7 @@ object GlutenConfig {
     buildConf("spark.gluten.sql.rewrite.dateTimestampComparison")
      .internal()
      .doc("Rewrite the comparision between date and timestamp to timestamp comparison."
-        + "For example `fron_unixtime(ts) > date` will be rewritten to `ts > to_unixtime(date)`")
+        + "For example `from_unixtime(ts) > date` will be rewritten to `ts > to_unixtime(date)`")
      .booleanConf
      .createWithDefault(true)
 
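
For readers of this patch, the core of the rewriteEqualTo change is that `from_unixtime(ts) = <date constant>` becomes a half-open range check on the raw unix timestamp rather than an equality on a truncated value. Below is a minimal plain-Scala sketch of that semantics, with no Spark dependency; the object name DateEqualitySketch, the helpers lowerBound and tsEqualsDate, and the fixed UTC zone and day-level unit (86400 seconds) are illustrative assumptions, not part of the patch, which resolves the zone via getTimeZoneId and the unit via getDateTimeUnit.

    import java.time.{LocalDate, ZoneOffset}

    // Hypothetical helper names; only the rewrite semantics mirror the patch.
    object DateEqualitySketch {
      val SecondsPerDay = 86400L // plays the role of adjustedOffset for a day-level unit

      // Epoch seconds at the start of the day; UTC is assumed here for simplicity.
      def lowerBound(date: LocalDate): Long =
        date.atStartOfDay(ZoneOffset.UTC).toEpochSecond

      // `from_unixtime(ts) = date` <=> lower <= ts && ts < lower + 86400
      def tsEqualsDate(ts: Long, date: LocalDate): Boolean = {
        val lower = lowerBound(date)
        ts >= lower && ts < lower + SecondsPerDay
      }

      def main(args: Array[String]): Unit = {
        val day = LocalDate.of(2024, 1, 9)
        println(tsEqualsDate(1704758400L, day)) // 2024-01-09 00:00:00 UTC -> true
        println(tsEqualsDate(1704801600L, day)) // 2024-01-09 12:00:00 UTC -> true
        println(tsEqualsDate(1704844800L, day)) // 2024-01-10 00:00:00 UTC -> false
      }
    }

The `ts >= lower && ts < lower + SecondsPerDay` check corresponds to the And(GreaterThanOrEqual(newLeft, newRight), LessThan(newLeft, Add(newRight, adjustedOffset))) expression built in the rewriteEqualTo hunk above.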