diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
index ffe8f2f9e445d..9380e88159d4a 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
@@ -897,6 +897,14 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
     GenericExpressionTransformer(substraitExprName, Seq(argument, function), expr)
   }
 
+  override def genDateAddTransformer(
+      attributeSeq: Seq[Attribute],
+      substraitExprName: String,
+      children: Seq[Expression],
+      expr: Expression): ExpressionTransformer = {
+    DateAddTransformer(attributeSeq, substraitExprName, children, expr).doTransform()
+  }
+
   override def genPreProjectForGenerate(generate: GenerateExec): SparkPlan = generate
 
   override def genPostProjectForGenerate(generate: GenerateExec): SparkPlan = generate
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/expression/DateAddTransformer.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/expression/DateAddTransformer.scala
new file mode 100644
index 0000000000000..9867b981090b9
--- /dev/null
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/expression/DateAddTransformer.scala
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.expression
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+
+case class DateAddTransformer(
+    attributeSeq: Seq[Attribute],
+    substraitExprName: String,
+    children: Seq[Expression],
+    expr: Expression)
+  extends Logging {
+
+  def doTransform(): ExpressionTransformer = {
+    expr match {
+      case _: DateAdd =>
+        doDateAddTransform()
+      case _: TimeAdd =>
+        doTimeAddTransform()
+      case _ => doDefaultTransform()
+    }
+  }
+
+  private def doDateAddTransform(): ExpressionTransformer = {
+    children(1) match {
+      case extractDays: ExtractANSIIntervalDays =>
+        extractDays.child match {
+          case literal: Literal if literal.dataType.isInstanceOf[DayTimeIntervalType] =>
+            val (intVal, unitVal) = parseDayIntervalType(literal)
+            if (unitVal == "DAY") {
+              val dateExpr =
+                ExpressionConverter.replaceWithExpressionTransformer(children(0), attributeSeq)
+              val daysExpr = LiteralTransformer(Literal(intVal, IntegerType))
+              GenericExpressionTransformer(substraitExprName, Seq(dateExpr, daysExpr), expr)
+            } else {
+              doDefaultTransform()
+            }
+          case _ => doDefaultTransform()
+        }
+      case _ => doDefaultTransform()
+    }
+  }
+
+  private def doTimeAddTransform(): ExpressionTransformer = {
+    children(1) match {
+      case literal: Literal if literal.dataType.isInstanceOf[DayTimeIntervalType] =>
+        val (intVal, unitVal) = parseDayIntervalType(literal)
+        if (unitVal != null) {
+          val timeExpr =
+            ExpressionConverter.replaceWithExpressionTransformer(children(0), attributeSeq)
+
+          val intExpr = Literal(intVal, IntegerType)
+          val unitExpr = Literal(UTF8String.fromString(unitVal), StringType)
+          val newExpr = expr.withNewChildren(Seq(unitExpr, intExpr))
+          GenericExpressionTransformer(
+            substraitExprName,
+            Seq(LiteralTransformer(unitExpr), LiteralTransformer(intExpr), timeExpr),
+            newExpr)
+        } else {
+          doDefaultTransform()
+        }
+      case _ => doDefaultTransform()
+    }
+  }
+
+  private def doDefaultTransform(): ExpressionTransformer = {
+    // transform the expression in a generic way
+    val childrenTransformers =
+      children.map(ExpressionConverter.replaceWithExpressionTransformer(_, attributeSeq))
+    GenericExpressionTransformer(substraitExprName, childrenTransformers, expr)
+  }
+
+  private def parseDayIntervalType(literal: Literal): (Int, String) = {
+    literal.dataType match {
+      case dayIntervalType: DayTimeIntervalType =>
+        val micros = literal.value.asInstanceOf[Long]
+        (dayIntervalType.startField, dayIntervalType.endField) match {
+          case (0, 0) => ((micros / 1000000L / 3600L / 24L).toInt, "DAY")
+          case (1, 1) => ((micros / 1000000L / 3600L).toInt, "HOUR")
+          case (2, 2) => ((micros / 1000000L / 60L).toInt, "MINUTE")
+          case (3, 3) => ((micros / 1000000L).toInt, "SECOND")
+          case _ => (0, null)
+        }
+      case _ => (0, null)
+    }
+  }
+}
diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala
index bd8a37d92c4e3..db8557cd8feb5 100644
--- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala
@@ -2727,5 +2727,33 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr
     compareResultsAgainstVanillaSpark(select_sql, true, { _ => })
     spark.sql("drop table test_tbl_4451")
   }
+
+  test("date_add and timestamp_add with day-time interval literals") {
+    spark.sql("create table tb_date(day Date) using parquet")
+    spark.sql("""
+                |insert into tb_date values
+                |(cast('2024-06-01' as Date)),
+                |(cast('2024-06-02' as Date)),
+                |(cast('2024-06-03' as Date)),
+                |(cast('2024-06-04' as Date)),
+                |(cast('2024-06-05' as Date))
+                |""".stripMargin)
+    val sql1 = """
+                 |select * from tb_date where day between
+                 |'2024-06-01' and
+                 |cast('2024-06-01' as Date) + interval 2 day
+                 |order by day
+                 |""".stripMargin
+    compareResultsAgainstVanillaSpark(sql1, true, { _ => })
+    val sql2 = """
+                 |select * from tb_date where day between
+                 |'2024-06-01' and
+                 |cast('2024-06-01' as Date) + interval 48 hour
+                 |order by day
+                 |""".stripMargin
+    compareResultsAgainstVanillaSpark(sql2, true, { _ => })
+
+    spark.sql("drop table tb_date")
+  }
 }
 // scalastyle:on line.size.limit
diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/timestampAdd.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/timestampAdd.cpp
index 6e92a7b928bc6..af81e2bd7455b 100644
--- a/cpp-ch/local-engine/Parser/scalar_function_parser/timestampAdd.cpp
+++ b/cpp-ch/local-engine/Parser/scalar_function_parser/timestampAdd.cpp
@@ -47,23 +47,27 @@ class FunctionParserTimestampAdd : public FunctionParser
     const ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func, ActionsDAGPtr & actions_dag) const override
     {
         auto parsed_args = parseFunctionArguments(substrait_func, actions_dag);
-        if (parsed_args.size() != 4)
-            throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires exactly four arguments", getName());
+        if (parsed_args.size() < 3)
+            throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires at least three arguments", getName());
 
         const auto & unit_field = substrait_func.arguments().at(0);
         if (!unit_field.value().has_literal() || !unit_field.value().literal().has_string())
             throw Exception(
                 ErrorCodes::BAD_ARGUMENTS, "Unsupported unit argument, should be a string literal, but: {}", unit_field.DebugString());
 
-        const auto & timezone_field = substrait_func.arguments().at(3);
-        if (!timezone_field.value().has_literal() || !timezone_field.value().literal().has_string())
-            throw Exception(
-                ErrorCodes::BAD_ARGUMENTS,
-                "Unsupported timezone_field argument, should be a string literal, but: {}",
-                timezone_field.DebugString());
+        String timezone;
+        if (parsed_args.size() == 4)
+        {
+            const auto & timezone_field = substrait_func.arguments().at(3);
+            if (!timezone_field.value().has_literal() || !timezone_field.value().literal().has_string())
+                throw Exception(
+                    ErrorCodes::BAD_ARGUMENTS,
+                    "Unsupported timezone_field argument, should be a string literal, but: {}",
+                    timezone_field.DebugString());
+            timezone = timezone_field.value().literal().string();
+        }
 
         const auto & unit = Poco::toUpper(unit_field.value().literal().string());
-        auto timezone = timezone_field.value().literal().string();
 
         std::string ch_function_name;
         if (unit == "MICROSECOND")
diff --git a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
index a69d41d00c12d..73158f6538af8 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
@@ -322,6 +322,18 @@ trait SparkPlanExecApi {
     throw new GlutenNotSupportException("PreciseTimestampConversion is not supported")
   }
 
+  // For date_add(cast('2001-01-01' as Date), interval 1 day), backends may handle the
+  // interval argument in different ways, so each backend can override this transformer.
+  def genDateAddTransformer(
+      attributeSeq: Seq[Attribute],
+      substraitExprName: String,
+      children: Seq[Expression],
+      expr: Expression): ExpressionTransformer = {
+    val childrenTransformers =
+      children.map(ExpressionConverter.replaceWithExpressionTransformer(_, attributeSeq))
+    GenericExpressionTransformer(substraitExprName, childrenTransformers, expr)
+  }
+
   /**
    * Generate ShuffleDependency for ColumnarShuffleExchangeExec.
    *
diff --git a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
index c285c11945c75..792aebcdc5f06 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
@@ -686,6 +686,20 @@ object ExpressionConverter extends SQLConfHelper with Logging {
       LiteralTransformer(Literal(Math.E))
     case p: Pi =>
       LiteralTransformer(Literal(Math.PI))
+    case dateAdd: DateAdd =>
+      BackendsApiManager.getSparkPlanExecApiInstance.genDateAddTransformer(
+        attributeSeq,
+        substraitExprName,
+        dateAdd.children,
+        dateAdd
+      )
+    case timeAdd: TimeAdd =>
+      BackendsApiManager.getSparkPlanExecApiInstance.genDateAddTransformer(
+        attributeSeq,
+        substraitExprName,
+        timeAdd.children,
+        timeAdd
+      )
     case expr =>
       GenericExpressionTransformer(
         substraitExprName,
diff --git a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala
index 77e85b3548665..9012ca8048d63 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala
@@ -177,6 +177,7 @@ object ExpressionMappings {
     Sig[Second](EXTRACT),
     Sig[FromUnixTime](FROM_UNIXTIME),
     Sig[DateAdd](DATE_ADD),
+    Sig[TimeAdd](TIMESTAMP_ADD),
    Sig[DateSub](DATE_SUB),
     Sig[DateDiff](DATE_DIFF),
     Sig[ToUnixTimestamp](TO_UNIX_TIMESTAMP),
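
Reviewer note: the interval decoding in `parseDayIntervalType` hinges on Spark storing a `DayTimeIntervalType` literal as a single `Long` of microseconds, with the `(startField, endField)` pair (0 = DAY, 1 = HOUR, 2 = MINUTE, 3 = SECOND) recording the unit the interval was written in. The following self-contained sketch (hypothetical object and method names, no Spark dependency) mirrors that mapping for anyone who wants to sanity-check the arithmetic:

```scala
// Standalone sketch of the microseconds -> (amount, unit) conversion performed
// by parseDayIntervalType. Field codes follow Spark's DayTimeIntervalType:
// 0 = DAY, 1 = HOUR, 2 = MINUTE, 3 = SECOND.
object DayTimeIntervalSketch {
  def toAmountAndUnit(micros: Long, startField: Int, endField: Int): (Int, String) =
    (startField, endField) match {
      case (0, 0) => ((micros / 1000000L / 3600L / 24L).toInt, "DAY")
      case (1, 1) => ((micros / 1000000L / 3600L).toInt, "HOUR")
      case (2, 2) => ((micros / 1000000L / 60L).toInt, "MINUTE")
      case (3, 3) => ((micros / 1000000L).toInt, "SECOND")
      // Mixed ranges such as DAY TO HOUR take the generic fallback path.
      case _ => (0, null)
    }

  def main(args: Array[String]): Unit = {
    // 'interval 2 day' and 'interval 48 hour' carry the same payload of
    // 172,800,000,000 microseconds; only the field codes differ.
    println(toAmountAndUnit(2L * 24 * 3600 * 1000000, 0, 0)) // (2, DAY)
    println(toAmountAndUnit(48L * 3600 * 1000000, 1, 1))     // (48, HOUR)
  }
}
```

This is also why the two test queries exercise different paths: Spark rewrites `+ interval 2 day` on a date into `DateAdd` over `ExtractANSIIntervalDays`, which `doDateAddTransform` handles, while `+ interval 48 hour` becomes a `TimeAdd`, which goes through `doTimeAddTransform` and the relaxed three-argument `timestamp_add` parser on the ClickHouse side.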