From 399a91b12d572375a7c8c7a617c7f04e2e3612e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E6=89=AC?= <654010905@qq.com> Date: Sun, 15 Sep 2024 17:24:11 +0800 Subject: [PATCH] [CH] Fix GlutenLiteralExpressionSuite and GlutenMathExpressionsSuite (#7235) * fix failed uts * Update CommonScalarFunctionParser.cpp * override checkResult for ch backend --- .../gluten/utils/CHExpressionUtil.scala | 3 +- .../CommonScalarFunctionParser.cpp | 2 - .../Parser/scalar_function_parser/shift.cpp | 105 ++++++++++++++++++ .../shiftRightUnsigned.cpp | 19 ++-- .../clickhouse/ClickHouseTestSettings.scala | 26 +---- .../GlutenLiteralExpressionSuite.scala | 38 ++++++- .../GlutenMathExpressionsSuite.scala | 51 +++++++++ 7 files changed, 208 insertions(+), 36 deletions(-) create mode 100644 cpp-ch/local-engine/Parser/scalar_function_parser/shift.cpp diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala index 3c9fa9888a5a..645189310a52 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala @@ -214,6 +214,7 @@ object CHExpressionUtil { STACK -> DefaultValidator(), TRANSFORM_KEYS -> DefaultValidator(), TRANSFORM_VALUES -> DefaultValidator(), - RAISE_ERROR -> DefaultValidator() + RAISE_ERROR -> DefaultValidator(), + WIDTH_BUCKET -> DefaultValidator() ) } diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/CommonScalarFunctionParser.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/CommonScalarFunctionParser.cpp index 37282104c644..88e5d7ea8a39 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/CommonScalarFunctionParser.cpp +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/CommonScalarFunctionParser.cpp @@ -102,8 +102,6 @@ REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Sign, sign, sign); REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Radians, radians, radians); REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Greatest, greatest, sparkGreatest); REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Least, least, sparkLeast); -REGISTER_COMMON_SCALAR_FUNCTION_PARSER(ShiftLeft, shiftleft, bitShiftLeft); -REGISTER_COMMON_SCALAR_FUNCTION_PARSER(ShiftRight, shiftright, bitShiftRight); REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Rand, rand, randCanonical); REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Bin, bin, sparkBin); REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Rint, rint, sparkRint); diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/shift.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/shift.cpp new file mode 100644 index 000000000000..663bf5e26388 --- /dev/null +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/shift.cpp @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include +#include +#include +#include + +namespace DB +{ +namespace ErrorCodes +{ + extern const int BAD_ARGUMENTS; + extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; +} +} + +namespace local_engine +{ + +class FunctionParserShiftBase : public FunctionParser +{ +public: + explicit FunctionParserShiftBase(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) { } + ~FunctionParserShiftBase() override = default; + + virtual String getCHFunctionName() const = 0; + + const ActionsDAG::Node * parse( + const substrait::Expression_ScalarFunction & substrait_func, + ActionsDAG & actions_dag) const override + { + /// parse spark shiftxxx(expr, n) as + /// If expr has long type -> CH bitShiftxxx(expr, pmod(n, 64)) + /// Otherwise -> CH bitShiftxxx(expr, pmod(n, 32)) + auto parsed_args = parseFunctionArguments(substrait_func, actions_dag); + if (parsed_args.size() != 2) + throw Exception(DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires exactly two arguments", getName()); + + + auto input_type = removeNullable(parsed_args[0]->result_type); + WhichDataType which(input_type); + const ActionsDAG::Node * base_node = nullptr; + if (which.isInt64()) + { + base_node = addColumnToActionsDAG(actions_dag, std::make_shared(), 64); + } + else if (which.isInt32()) + { + base_node = addColumnToActionsDAG(actions_dag, std::make_shared(), 32); + } + else + throw Exception(DB::ErrorCodes::BAD_ARGUMENTS, "First argument for function {} must be an long or integer", getName()); + + const auto * pmod_node = toFunctionNode(actions_dag, "pmod", {parsed_args[1], base_node}); + auto ch_function_name = getCHFunctionName(); + const auto * shift_node = toFunctionNode(actions_dag, ch_function_name, {parsed_args[0], pmod_node}); + return convertNodeTypeIfNeeded(substrait_func, shift_node, actions_dag); + } +}; + +class FunctionParserShiftLeft : public FunctionParserShiftBase +{ +public: + explicit FunctionParserShiftLeft(SerializedPlanParser * plan_parser_) : FunctionParserShiftBase(plan_parser_) { } + ~FunctionParserShiftLeft() override = default; + + static constexpr auto name = "shiftleft"; + String getName() const override { return name; } + + String getCHFunctionName() const override { return "bitShiftLeft"; } +}; +static FunctionParserRegister register_shiftleft; + +class FunctionParserShiftRight: public FunctionParserShiftBase +{ +public: + explicit FunctionParserShiftRight(SerializedPlanParser * plan_parser_) : FunctionParserShiftBase(plan_parser_) { } + ~FunctionParserShiftRight() override = default; + + static constexpr auto name = "shiftright"; + String getName() const override { return name; } + + String getCHFunctionName() const override { return "bitShiftRight"; } +}; +static FunctionParserRegister register_shiftright; + + +} diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/shiftRightUnsigned.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/shiftRightUnsigned.cpp index 28288461a1da..ca88e852234b 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/shiftRightUnsigned.cpp +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/shiftRightUnsigned.cpp @@ -43,9 +43,9 @@ class FunctionParserShiftRightUnsigned : public FunctionParser { /// parse shiftrightunsigned(a, b) as /// if (isInteger(a)) - /// bitShiftRight(a::UInt32, b::UInt32) + /// bitShiftRight(a::UInt32, pmod(b, 32)) /// else if (isLong(a)) - /// bitShiftRight(a::UInt64, b::UInt64) + /// bitShiftRight(a::UInt64, pmod(b, 32)) /// else /// throw Exception @@ -55,26 +55,27 @@ class FunctionParserShiftRightUnsigned : public FunctionParser const auto * a = parsed_args[0]; const auto * b = parsed_args[1]; - const auto * new_a = a; - const auto * new_b = b; WhichDataType which(removeNullable(a->result_type)); + const ActionsDAG::Node * base_node = nullptr; + const ActionsDAG::Node * unsigned_a_node = nullptr; if (which.isInt32()) { + base_node = addColumnToActionsDAG(actions_dag, std::make_shared(), 32); const auto * uint32_type_node = addColumnToActionsDAG(actions_dag, std::make_shared(), "Nullable(UInt32)"); - new_a = toFunctionNode(actions_dag, "CAST", {a, uint32_type_node}); - new_b = toFunctionNode(actions_dag, "CAST", {b, uint32_type_node}); + unsigned_a_node = toFunctionNode(actions_dag, "CAST", {a, uint32_type_node}); } else if (which.isInt64()) { + base_node = addColumnToActionsDAG(actions_dag, std::make_shared(), 64); const auto * uint64_type_node = addColumnToActionsDAG(actions_dag, std::make_shared(), "Nullable(UInt64)"); - new_a = toFunctionNode(actions_dag, "CAST", {a, uint64_type_node}); - new_b = toFunctionNode(actions_dag, "CAST", {b, uint64_type_node}); + unsigned_a_node = toFunctionNode(actions_dag, "CAST", {a, uint64_type_node}); } else throw Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Function {} requires integer or long as first argument", getName()); - const auto * result = toFunctionNode(actions_dag, "bitShiftRight", {new_a, new_b}); + const auto * pmod_node = toFunctionNode(actions_dag, "pmod", {b, base_node}); + const auto * result = toFunctionNode(actions_dag, "bitShiftRight", {unsigned_a_node, pmod_node}); return convertNodeTypeIfNeeded(substrait_func, result, actions_dag); } }; diff --git a/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala index f40007957507..c260d3f8029b 100644 --- a/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala +++ b/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala @@ -788,32 +788,12 @@ class ClickHouseTestSettings extends BackendTestSettings { .exclude("SPARK-35728: Check multiply/divide of day-time intervals of any fields by numeric") .exclude("SPARK-35778: Check multiply/divide of year-month intervals of any fields by numeric") enableSuite[GlutenLiteralExpressionSuite] - .exclude("null") .exclude("default") - .exclude("decimal") - .exclude("array") - .exclude("seq") - .exclude("map") - .exclude("struct") - .exclude("SPARK-35664: construct literals from java.time.LocalDateTime") - .exclude("SPARK-34605: construct literals from java.time.Duration") - .exclude("SPARK-34605: construct literals from arrays of java.time.Duration") - .exclude("SPARK-34615: construct literals from java.time.Period") - .exclude("SPARK-34615: construct literals from arrays of java.time.Period") - .exclude("SPARK-35871: Literal.create(value, dataType) should support fields") .exclude("SPARK-37967: Literal.create support ObjectType") enableSuite[GlutenMathExpressionsSuite] - .exclude("tanh") - .exclude("unhex") - .exclude("atan2") - .exclude("round/bround/floor/ceil") - .exclude("SPARK-36922: Support ANSI intervals for SIGN/SIGNUM") - .exclude("SPARK-35926: Support YearMonthIntervalType in width-bucket function") - .exclude("SPARK-35925: Support DayTimeIntervalType in width-bucket function") - .exclude("SPARK-37388: width_bucket") - .exclude("shift left") - .exclude("shift right") - .exclude("shift right unsigned") + .exclude("unhex") // https://github.com/apache/incubator-gluten/issues/7232 + .exclude("round/bround/floor/ceil") // https://github.com/apache/incubator-gluten/issues/7233 + .exclude("atan2") // https://github.com/apache/incubator-gluten/issues/7233 enableSuite[GlutenMiscExpressionsSuite] enableSuite[GlutenNondeterministicSuite] .exclude("MonotonicallyIncreasingID") diff --git a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenLiteralExpressionSuite.scala b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenLiteralExpressionSuite.scala index 556d185af078..f81ef0b6ff3a 100644 --- a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenLiteralExpressionSuite.scala +++ b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenLiteralExpressionSuite.scala @@ -17,5 +17,41 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.GlutenTestsTrait +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.CalendarInterval -class GlutenLiteralExpressionSuite extends LiteralExpressionSuite with GlutenTestsTrait {} +import java.nio.charset.StandardCharsets +import java.time.{Instant, LocalDate} + +class GlutenLiteralExpressionSuite extends LiteralExpressionSuite with GlutenTestsTrait { + testGluten("default") { + checkEvaluation(Literal.default(BooleanType), false) + checkEvaluation(Literal.default(ByteType), 0.toByte) + checkEvaluation(Literal.default(ShortType), 0.toShort) + checkEvaluation(Literal.default(IntegerType), 0) + checkEvaluation(Literal.default(LongType), 0L) + checkEvaluation(Literal.default(FloatType), 0.0f) + checkEvaluation(Literal.default(DoubleType), 0.0) + checkEvaluation(Literal.default(StringType), "") + checkEvaluation(Literal.default(BinaryType), "".getBytes(StandardCharsets.UTF_8)) + checkEvaluation(Literal.default(DecimalType.USER_DEFAULT), Decimal(0)) + checkEvaluation(Literal.default(DecimalType.SYSTEM_DEFAULT), Decimal(0)) + withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> "false") { + checkEvaluation(Literal.default(DateType), DateTimeUtils.toJavaDate(0)) + checkEvaluation(Literal.default(TimestampType), DateTimeUtils.toJavaTimestamp(0L)) + } + withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> "true") { + checkEvaluation(Literal.default(DateType), LocalDate.ofEpochDay(0)) + checkEvaluation(Literal.default(TimestampType), Instant.ofEpochSecond(0)) + } + checkEvaluation(Literal.default(CalendarIntervalType), new CalendarInterval(0, 0, 0L)) + checkEvaluation(Literal.default(YearMonthIntervalType()), 0) + checkEvaluation(Literal.default(DayTimeIntervalType()), 0L) + checkEvaluation(Literal.default(ArrayType(StringType)), Array()) + checkEvaluation(Literal.default(MapType(IntegerType, StringType)), Map()) + checkEvaluation(Literal.default(StructType(StructField("a", StringType) :: Nil)), Row("")) + } +} diff --git a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenMathExpressionsSuite.scala b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenMathExpressionsSuite.scala index 7085b70ae38c..a8716b6effac 100644 --- a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenMathExpressionsSuite.scala +++ b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenMathExpressionsSuite.scala @@ -18,11 +18,44 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.gluten.utils.BackendTestUtils +import org.apache.spark.sql.GlutenQueryTestUtil.isNaNOrInf import org.apache.spark.sql.GlutenTestsTrait import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.types._ +import org.apache.commons.math3.util.Precision + +import java.nio.charset.StandardCharsets + class GlutenMathExpressionsSuite extends MathExpressionsSuite with GlutenTestsTrait { + override protected def checkResult( + result: Any, + expected: Any, + exprDataType: DataType, + exprNullable: Boolean): Boolean = { + if (BackendTestUtils.isVeloxBackendLoaded()) { + super.checkResult(result, expected, exprDataType, exprNullable) + } else { + // The result is null for a non-nullable expression + assert(result != null || exprNullable, "exprNullable should be true if result is null") + (result, expected) match { + case (result: Double, expected: Double) => + if ( + (isNaNOrInf(result) || isNaNOrInf(expected)) + || (result == -0.0) || (expected == -0.0) + ) { + java.lang.Double.doubleToRawLongBits(result) == + java.lang.Double.doubleToRawLongBits(expected) + } else { + Precision.equalsWithRelativeTolerance(result, expected, 0.00001d) || + Precision.equals(result, expected, 0.00001d) + } + case _ => + super.checkResult(result, expected, exprDataType, exprNullable) + } + } + } + testGluten("round/bround/floor/ceil") { val scales = -6 to 6 val doublePi: Double = math.Pi @@ -284,4 +317,22 @@ class GlutenMathExpressionsSuite extends MathExpressionsSuite with GlutenTestsTr checkEvaluation(checkDataTypeAndCast(RoundCeil(Literal(3.1411), Literal(-3))), Decimal(1000)) checkEvaluation(checkDataTypeAndCast(RoundCeil(Literal(135.135), Literal(-2))), Decimal(200)) } + + testGluten("unhex") { + checkEvaluation(Unhex(Literal.create(null, StringType)), null) + checkEvaluation(Unhex(Literal("737472696E67")), "string".getBytes(StandardCharsets.UTF_8)) + checkEvaluation(Unhex(Literal("")), new Array[Byte](0)) + checkEvaluation(Unhex(Literal("F")), Array[Byte](15)) + checkEvaluation(Unhex(Literal("ff")), Array[Byte](-1)) + +// checkEvaluation(Unhex(Literal("GG")), null) + checkEvaluation(Unhex(Literal("123")), Array[Byte](1, 35)) + checkEvaluation(Unhex(Literal("12345")), Array[Byte](1, 35, 69)) + // scalastyle:off + // Turn off scala style for non-ascii chars + checkEvaluation(Unhex(Literal("E4B889E9878DE79A84")), "三重的".getBytes(StandardCharsets.UTF_8)) +// checkEvaluation(Unhex(Literal("三重的")), null) + // scalastyle:on + checkConsistencyBetweenInterpretedAndCodegen(Unhex, StringType) + } }