support time interval
lgbo-ustc committed Jul 15, 2024
1 parent 997c6e3 commit 3b6c3e1
Showing 7 changed files with 176 additions and 5 deletions.
@@ -897,6 +897,14 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
    GenericExpressionTransformer(substraitExprName, Seq(argument, function), expr)
  }

  override def genDateAddTransformer(
      attributeSeq: Seq[Attribute],
      substraitExprName: String,
      children: Seq[Expression],
      expr: Expression): ExpressionTransformer = {
    DateAddTransformer(attributeSeq, substraitExprName, children, expr).doTransform()
  }

  override def genPreProjectForGenerate(generate: GenerateExec): SparkPlan = generate

  override def genPostProjectForGenerate(generate: GenerateExec): SparkPlan = generate
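The ClickHouse override hands the untouched child list to the new DateAddTransformer (the file added below), which chooses between a literal-interval fast path and the generic fallback.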
@@ -0,0 +1,104 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.gluten.expression

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

case class DateAddTransformer(
    attributeSeq: Seq[Attribute],
    substraitExprName: String,
    children: Seq[Expression],
    expr: Expression)
  extends Logging {

  def doTransform(): ExpressionTransformer = {
    expr match {
      case _: DateAdd =>
        doDateAddTransform()
      case _: TimeAdd =>
        doTimeAddTransform()
      case _ => doDefaultTransform()
    }
  }

  // date + interval '<n>' day is resolved by Spark into
  // DateAdd(date, ExtractANSIIntervalDays(interval)); lower it to a plain
  // date_add(date, n) when the interval is a day-only literal.
  private def doDateAddTransform(): ExpressionTransformer = {
    children(1) match {
      case extractDays: ExtractANSIIntervalDays =>
        extractDays.child match {
          case literal: Literal if literal.dataType.isInstanceOf[DayTimeIntervalType] =>
            val (intVal, unitVal) = parseDayIntervalType(literal)
            if ("DAY".equals(unitVal)) {
              val dateExpr =
                ExpressionConverter.replaceWithExpressionTransformer(children(0), attributeSeq)
              val daysExpr = LiteralTransformer(Literal(intVal, IntegerType))
              GenericExpressionTransformer(substraitExprName, Seq(dateExpr, daysExpr), expr)
            } else {
              doDefaultTransform()
            }
          case _ => doDefaultTransform()
        }
      case _ => doDefaultTransform()
    }
  }

  // timestamp + interval is resolved into TimeAdd(timestamp, interval); rewrite a
  // literal interval into (unit, count, timestamp) arguments for timestamp_add.
  private def doTimeAddTransform(): ExpressionTransformer = {
    children(1) match {
      case literal: Literal if literal.dataType.isInstanceOf[DayTimeIntervalType] =>
        val (intVal, unitVal) = parseDayIntervalType(literal)
        if (unitVal != null) {
          val timeExpr =
            ExpressionConverter.replaceWithExpressionTransformer(children(0), attributeSeq)
          val intExpr = Literal(intVal, IntegerType)
          val unitExpr = Literal(UTF8String.fromString(unitVal), StringType)
          val newExpr = expr.withNewChildren(Seq(unitExpr, intExpr))
          GenericExpressionTransformer(
            substraitExprName,
            Seq(LiteralTransformer(unitExpr), LiteralTransformer(intExpr), timeExpr),
            newExpr)
        } else {
          doDefaultTransform()
        }
      case _ => doDefaultTransform()
    }
  }

  private def doDefaultTransform(): ExpressionTransformer = {
    // Transform it in a generic way.
    val childrenTransformers =
      children.map(ExpressionConverter.replaceWithExpressionTransformer(_, attributeSeq))
    GenericExpressionTransformer(substraitExprName, childrenTransformers, expr)
  }

  // A DayTimeIntervalType literal stores its value as microseconds in a Long.
  // Map a single-field interval to (count, unit); return (0, null) otherwise.
  private def parseDayIntervalType(literal: Literal): (Integer, String) = {
    literal.dataType match {
      case dayIntervalType: DayTimeIntervalType =>
        val micros = literal.value.asInstanceOf[Long]
        (dayIntervalType.startField, dayIntervalType.endField) match {
          case (0, 0) => ((micros / 1000000L / 3600L / 24L).toInt, "DAY")
          case (1, 1) => ((micros / 1000000L / 3600L).toInt, "HOUR")
          case (2, 2) => ((micros / 1000000L / 60L).toInt, "MINUTE")
          case (3, 3) => ((micros / 1000000L).toInt, "SECOND")
          case _ => (0, null)
        }
      case _ => (0, null)
    }
  }
}
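A DayTimeIntervalType literal carries its value as a microsecond count in a Long, which is what the divisor chains in parseDayIntervalType recover. A worked example of the same arithmetic (standalone sketch, values illustrative):

  // `interval 2 day` is stored as 2 * 24 * 3600 * 1000000 microseconds
  val micros = 172800000000L
  (micros / 1000000L / 3600L / 24L).toInt // 2  -> (2, "DAY")
  (micros / 1000000L / 3600L).toInt       // 48 -> (48, "HOUR")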
@@ -2727,5 +2727,33 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr
    compareResultsAgainstVanillaSpark(select_sql, true, { _ => })
    spark.sql("drop table test_tbl_4451")
  }

  test("date add with day-time interval") {
    spark.sql("create table tb_date(day Date) using parquet")
    spark.sql("""
                |insert into tb_date values
                |(cast('2024-06-01' as Date)),
                |(cast('2024-06-02' as Date)),
                |(cast('2024-06-03' as Date)),
                |(cast('2024-06-04' as Date)),
                |(cast('2024-06-05' as Date))
                |""".stripMargin)
    val sql1 = """
                 |select * from tb_date where day between
                 |'2024-06-01' and
                 |cast('2024-06-01' as Date) + interval 2 day
                 |order by day
                 |""".stripMargin
    compareResultsAgainstVanillaSpark(sql1, true, { _ => })
    val sql2 = """
                 |select * from tb_date where day between
                 |'2024-06-01' and
                 |cast('2024-06-01' as Date) + interval 48 hour
                 |order by day
                 |""".stripMargin
    compareResultsAgainstVanillaSpark(sql2, true, { _ => })

    spark.sql("drop table tb_date")
  }
}
// scalastyle:on line.size.limit
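The two queries deliberately exercise both new code paths. Assuming Spark's usual resolution of ANSI intervals (a sketch of the analyzed trees, not output from the suite), the day-only interval stays a DateAdd while the hour interval promotes the date to a timestamp and becomes a TimeAdd:

  -- sql1: matched by doDateAddTransform
  DateAdd(cast('2024-06-01' as date), ExtractANSIIntervalDays(INTERVAL '2' DAY))
  -- sql2: matched by doTimeAddTransform
  TimeAdd(cast(cast('2024-06-01' as date) as timestamp), INTERVAL '48' HOUR)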
@@ -47,23 +47,27 @@ class FunctionParserTimestampAdd : public FunctionParser
     const ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func, ActionsDAGPtr & actions_dag) const override
     {
         auto parsed_args = parseFunctionArguments(substrait_func, actions_dag);
-        if (parsed_args.size() != 4)
-            throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires exactly four arguments", getName());
+        if (parsed_args.size() < 3)
+            throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires at least three arguments", getName());

         const auto & unit_field = substrait_func.arguments().at(0);
         if (!unit_field.value().has_literal() || !unit_field.value().literal().has_string())
             throw Exception(
                 ErrorCodes::BAD_ARGUMENTS, "Unsupported unit argument, should be a string literal, but: {}", unit_field.DebugString());

-        const auto & timezone_field = substrait_func.arguments().at(3);
-        if (!timezone_field.value().has_literal() || !timezone_field.value().literal().has_string())
+        String timezone;
+        if (parsed_args.size() == 4)
+        {
+            const auto & timezone_field = substrait_func.arguments().at(3);
+            if (!timezone_field.value().has_literal() || !timezone_field.value().literal().has_string())
                 throw Exception(
                     ErrorCodes::BAD_ARGUMENTS,
                     "Unsupported timezone_field argument, should be a string literal, but: {}",
                     timezone_field.DebugString());
+            timezone = timezone_field.value().literal().string();
+        }

         const auto & unit = Poco::toUpper(unit_field.value().literal().string());
-        auto timezone = timezone_field.value().literal().string();

         std::string ch_function_name;
         if (unit == "MICROSECOND")
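The relaxed arity matches the Scala side: the rewritten TimeAdd emits timestamp_add with only (unit, count, timestamp), so the trailing timezone literal is now optional and timezone stays empty when absent. Illustrative call shapes (assumed, not from the commit):

  timestamp_add('HOUR', 48, ts)           -- 3 args, new path
  timestamp_add('HOUR', 48, ts, 'UTC')    -- 4 args, with an explicit timezone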
@@ -322,6 +322,18 @@ trait SparkPlanExecApi {
    throw new GlutenNotSupportException("PreciseTimestampConversion is not supported")
  }

  // For date_add(cast('2001-01-01' as Date), interval 1 day), backends may handle it in
  // different ways.
  def genDateAddTransformer(
      attributeSeq: Seq[Attribute],
      substraitExprName: String,
      children: Seq[Expression],
      expr: Expression): ExpressionTransformer = {
    val childrenTransformers =
      children.map(ExpressionConverter.replaceWithExpressionTransformer(_, attributeSeq))
    GenericExpressionTransformer(substraitExprName, childrenTransformers, expr)
  }
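Backends that need no special casing inherit this generic fallback; the ClickHouse backend overrides it (see the first hunk above) to delegate to DateAddTransformer.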

  /**
   * Generate ShuffleDependency for ColumnarShuffleExchangeExec.
   *
@@ -686,6 +686,20 @@ object ExpressionConverter extends SQLConfHelper with Logging {
        LiteralTransformer(Literal(Math.E))
      case p: Pi =>
        LiteralTransformer(Literal(Math.PI))
      case dateAdd: DateAdd =>
        BackendsApiManager.getSparkPlanExecApiInstance.genDateAddTransformer(
          attributeSeq,
          substraitExprName,
          dateAdd.children,
          dateAdd)
      case timeAdd: TimeAdd =>
        BackendsApiManager.getSparkPlanExecApiInstance.genDateAddTransformer(
          attributeSeq,
          substraitExprName,
          timeAdd.children,
          timeAdd)
      case expr =>
        GenericExpressionTransformer(
          substraitExprName,
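Both arms funnel into the same backend hook: for example, day + interval 2 day and ts + interval 30 minute each reach genDateAddTransformer, and a backend without an override simply gets the trait's generic transformer.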
@@ -177,6 +177,7 @@ object ExpressionMappings {
    Sig[Second](EXTRACT),
    Sig[FromUnixTime](FROM_UNIXTIME),
    Sig[DateAdd](DATE_ADD),
    Sig[TimeAdd](TIMESTAMP_ADD),
    Sig[DateSub](DATE_SUB),
    Sig[DateDiff](DATE_DIFF),
    Sig[ToUnixTimestamp](TO_UNIX_TIMESTAMP),
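With Sig[TimeAdd](TIMESTAMP_ADD) registered, TimeAdd expressions can be mapped to a Substrait timestamp_add call instead of falling back to vanilla Spark. An illustrative query that now takes this path (assumed, not from the commit):

  select cast('2024-06-01 00:00:00' as timestamp) + interval 30 minute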
