From 89015c91525734d3080e779f2b944e5982955171 Mon Sep 17 00:00:00 2001
From: Zhichao Zhang
Date: Wed, 24 Apr 2024 23:52:05 +0800
Subject: [PATCH] [GLUTEN-5512][CH] Fix the incorrect transformer for the
 round function with the decimal data type (#5513)

[CH] Fix the incorrect transformer for the round function with the
decimal data type.

When transforming the round function with the decimal data type, the
converter only transformed the `child` expression and missed the
`scale` expression, which led to incorrect results when executing
`round(decimal_data, x)`.

Close #5512.

Notes with illustrative sketches follow the diff.
---
 .../GlutenClickHouseTPCHBucketSuite.scala     | 23 ++++--
 .../GlutenFunctionValidateSuite.scala         | 55 +++++++++++--
 .../velox/VeloxSparkPlanExecApi.scala         | 10 +--
 .../Functions/SparkFunctionRoundHalfUp.h      | 81 ++++++++++++++-----
 .../gluten/backendsapi/SparkPlanExecApi.scala |  7 --
 .../expression/ExpressionConverter.scala      |  5 +-
 .../clickhouse/ClickHouseTestSettings.scala   |  1 -
 .../GlutenMathExpressionsSuite.scala          | 10 ++-
 .../clickhouse/ClickHouseTestSettings.scala   |  1 -
 .../GlutenMathExpressionsSuite.scala          | 10 ++-
 .../sql/shims/spark33/Spark33Shims.scala      |  6 +-
 11 files changed, 151 insertions(+), 58 deletions(-)

diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHBucketSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHBucketSuite.scala
index 8695e9483dfb..79a708ce50eb 100644
--- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHBucketSuite.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHBucketSuite.scala
@@ -17,7 +17,7 @@
 package org.apache.gluten.execution
 
 import org.apache.spark.SparkConf
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Row, TestUtils}
 import org.apache.spark.sql.execution.InputIteratorTransformer
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.aggregate.SortAggregateExec
@@ -582,11 +582,18 @@ class GlutenClickHouseTPCHBucketSuite
     assert(plans.size == expectedCount)
   }
 
-  def checkResult(df: DataFrame, exceptedResult: Array[Row]): Unit = {
+  def checkResult(df: DataFrame, exceptedResult: Seq[Row]): Unit = {
     // check the result
     val result = df.collect()
     assert(result.size == exceptedResult.size)
-    result.equals(exceptedResult)
+    val sortedRes = result.map {
+      s =>
+        Row.fromSeq(s.toSeq.map {
+          case a: mutable.WrappedArray[_] => a.sortBy(_.toString.toInt)
+          case o => o
+        })
+    }
+    TestUtils.compareAnswers(sortedRes, exceptedResult)
   }
 
   val SQL =
@@ -600,10 +607,10 @@ class GlutenClickHouseTPCHBucketSuite
     checkResult(
       df,
       Array(
-        Row(1, "N", mutable.WrappedArray.make(Array(3, 6, 1, 5, 2, 4))),
+        Row(1, "N", mutable.WrappedArray.make(Array(1, 2, 3, 4, 5, 6))),
         Row(2, "N", mutable.WrappedArray.make(Array(1))),
-        Row(3, "A", mutable.WrappedArray.make(Array(6, 4, 3))),
-        Row(3, "R", mutable.WrappedArray.make(Array(2, 5, 1))),
+        Row(3, "A", mutable.WrappedArray.make(Array(3, 4, 6))),
+        Row(3, "R", mutable.WrappedArray.make(Array(1, 2, 5))),
         Row(4, "N", mutable.WrappedArray.make(Array(1)))
       )
     )
     checkResult(
       df,
       Array(
-        Row("A", 3, mutable.WrappedArray.make(Array(6, 4, 3))),
+        Row("A", 3, mutable.WrappedArray.make(Array(3, 4, 6))),
         Row("A", 5, mutable.WrappedArray.make(Array(3))),
         Row("A", 6, mutable.WrappedArray.make(Array(1))),
         Row("A", 33, mutable.WrappedArray.make(Array(1, 2, 3))),
- Row("A", 37, mutable.WrappedArray.make(Array(2, 3, 1))) + Row("A", 37, mutable.WrappedArray.make(Array(1, 2, 3))) ) ) checkHashAggregateCount(df, 1) diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala index 8f7dcd6456c2..7b52a970ef08 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala @@ -20,7 +20,7 @@ import org.apache.gluten.GlutenConfig import org.apache.gluten.utils.UTSystemParameters import org.apache.spark.SparkConf -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Row, TestUtils} import org.apache.spark.sql.catalyst.optimizer.{ConstantFolding, NullPropagation} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -495,13 +495,54 @@ class GlutenFunctionValidateSuite extends GlutenClickHouseWholeStageTransformerS } test("test round issue: https://github.com/oap-project/gluten/issues/3462") { - runQueryAndCompare( - "select round(0.41875d * id , 4) from range(10);" - )(checkGlutenOperatorMatch[ProjectExecTransformer]) + def checkResult(df: DataFrame, exceptedResult: Seq[Row]): Unit = { + // check the result + val result = df.collect() + assert(result.size == exceptedResult.size) + TestUtils.compareAnswers(result, exceptedResult) + } - runQueryAndCompare( - "select round(0.41875f * id , 4) from range(10);" - )(checkGlutenOperatorMatch[ProjectExecTransformer]) + runSql("select round(0.41875d * id , 4) from range(10);")( + df => { + checkGlutenOperatorMatch[ProjectExecTransformer](df) + + checkResult( + df, + Seq( + Row(0.0), + Row(0.4188), + Row(0.8375), + Row(1.2563), + Row(1.675), + Row(2.0938), + Row(2.5125), + Row(2.9313), + Row(3.35), + Row(3.7688) + ) + ) + }) + + runSql("select round(0.41875f * id , 4) from range(10);")( + df => { + checkGlutenOperatorMatch[ProjectExecTransformer](df) + + checkResult( + df, + Seq( + Row(0.0f), + Row(0.4188f), + Row(0.8375f), + Row(1.2562f), + Row(1.675f), + Row(2.0938f), + Row(2.5125f), + Row(2.9312f), + Row(3.35f), + Row(3.7688f) + ) + ) + }) } test("test date comparision expression override") { diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala index 7463c6340f75..e9bd47bec0bd 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala @@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.{AggregateFunctionRewriteRule, FlushableHas import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec -import org.apache.spark.sql.catalyst.expressions.{Add, Alias, ArrayExists, ArrayFilter, ArrayForAll, ArrayTransform, Ascending, Attribute, Cast, CreateNamedStruct, ElementAt, Expression, ExpressionInfo, Generator, GetArrayItem, GetMapValue, GetStructField, If, IsNaN, LambdaFunction, Literal, Murmur3Hash, NamedExpression, NaNvl, PosExplode, Round, SortOrder, StringSplit, StringTrim, TryEval, Uuid} +import 
org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, HLLAdapter} import org.apache.spark.sql.catalyst.optimizer.BuildSide import org.apache.spark.sql.catalyst.plans.JoinType @@ -572,14 +572,6 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi { * * Expressions. */ - /** Generates a transformer for decimal round. */ - override def genDecimalRoundTransformer( - substraitExprName: String, - child: ExpressionTransformer, - original: Round): ExpressionTransformer = { - DecimalRoundTransformer(substraitExprName, child, original) - } - /** Generate StringSplit transformer. */ override def genStringSplitTransformer( substraitExprName: String, diff --git a/cpp-ch/local-engine/Functions/SparkFunctionRoundHalfUp.h b/cpp-ch/local-engine/Functions/SparkFunctionRoundHalfUp.h index ab4faf23575e..47135aabd94f 100644 --- a/cpp-ch/local-engine/Functions/SparkFunctionRoundHalfUp.h +++ b/cpp-ch/local-engine/Functions/SparkFunctionRoundHalfUp.h @@ -18,32 +18,77 @@ #include - namespace local_engine { using namespace DB; template -class BaseFloatRoundingHalfUpComputation +class BaseFloatRoundingHalfUpComputation; + +template <> +class BaseFloatRoundingHalfUpComputation +{ +public: + using ScalarType = Float32; + using VectorType = __m128; + static const size_t data_count = 4; + + static VectorType load(const ScalarType * in) { return _mm_loadu_ps(in); } + static VectorType load1(const ScalarType in) { return _mm_load1_ps(&in); } + static void store(ScalarType * out, VectorType val) { _mm_storeu_ps(out, val);} + static VectorType multiply(VectorType val, VectorType scale) { return _mm_mul_ps(val, scale); } + static VectorType divide(VectorType val, VectorType scale) { return _mm_div_ps(val, scale); } + template static VectorType apply(VectorType val) + { + ScalarType tempFloatsIn[data_count]; + ScalarType tempFloatsOut[data_count]; + store(tempFloatsIn, val); + for (size_t i = 0; i < data_count; ++i) + tempFloatsOut[i] = std::roundf(tempFloatsIn[i]); + + return load(tempFloatsOut); + } + + static VectorType prepare(size_t scale) + { + return load1(scale); + } +}; + +template <> +class BaseFloatRoundingHalfUpComputation { public: - using ScalarType = T; - using VectorType = Float64; - static const size_t data_count = 1; - - static VectorType load(const ScalarType * in) { return static_cast(*in); } - static VectorType load1(ScalarType in) { return in; } - static ScalarType store(ScalarType * out, VectorType val) { return *out = static_cast(val); } - static VectorType multiply(VectorType val, VectorType scale) { return val * scale; } - static VectorType divide(VectorType val, VectorType scale) { return val / scale; } - static VectorType apply(VectorType val) { return round(val); } - static VectorType prepare(size_t scale) { return load1(scale); } + using ScalarType = Float64; + using VectorType = __m128d; + static const size_t data_count = 2; + + static VectorType load(const ScalarType * in) { return _mm_loadu_pd(in); } + static VectorType load1(const ScalarType in) { return _mm_load1_pd(&in); } + static void store(ScalarType * out, VectorType val) { _mm_storeu_pd(out, val);} + static VectorType multiply(VectorType val, VectorType scale) { return _mm_mul_pd(val, scale); } + static VectorType divide(VectorType val, VectorType scale) { return _mm_div_pd(val, scale); } + template static VectorType apply(VectorType val) + { + ScalarType tempFloatsIn[data_count]; + ScalarType tempFloatsOut[data_count]; + store(tempFloatsIn, val); + for (size_t i 
+            tempFloatsOut[i] = std::round(tempFloatsIn[i]);
+
+        return load(tempFloatsOut);
+    }
+
+    static VectorType prepare(size_t scale)
+    {
+        return load1(scale);
+    }
 };
 
 
 /** Implementation of low-level round-off functions for floating-point values.
   */
-template <typename T, ScaleMode scale_mode>
+template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode>
 class FloatRoundingHalfUpComputation : public BaseFloatRoundingHalfUpComputation<T>
 {
     using Base = BaseFloatRoundingHalfUpComputation<T>;
@@ -58,7 +103,7 @@ class FloatRoundingHalfUpComputation : public BaseFloatRoundingHalfUpComputation
         else if (scale_mode == ScaleMode::Negative)
             val = Base::divide(val, scale);
 
-        val = Base::apply(val);
+        val = Base::template apply<rounding_mode>(val);
 
         if (scale_mode == ScaleMode::Positive)
             val = Base::divide(val, scale);
@@ -72,13 +117,13 @@
 
 /** Implementing high-level rounding functions.
   */
-template <typename T, ScaleMode scale_mode>
+template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode>
 struct FloatRoundingHalfUpImpl
 {
 private:
     static_assert(!is_decimal<T>);
 
-    using Op = FloatRoundingHalfUpComputation<T, scale_mode>;
+    using Op = FloatRoundingHalfUpComputation<T, rounding_mode, scale_mode>;
     using Data = std::array<T, Op::data_count>;
     using ColumnType = ColumnVector<T>;
     using Container = typename ColumnType::Container;
@@ -125,7 +170,7 @@ struct DispatcherRoundingHalfUp
 {
     template <ScaleMode scale_mode>
     using FunctionRoundingImpl = std::conditional_t<
        std::is_floating_point_v<T>,
-        FloatRoundingHalfUpImpl<T, scale_mode>,
+        FloatRoundingHalfUpImpl<T, rounding_mode, scale_mode>,
         IntegerRoundingImpl<T, rounding_mode, scale_mode, tie_breaking_mode>>;
 
     static ColumnPtr apply(const IColumn * col_general, Scale scale_arg)
diff --git a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
index 8dfb1e641ce7..1cffe3cb166e 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
@@ -415,13 +415,6 @@ trait SparkPlanExecApi {
    */
   def genExtendedColumnarPostRules(): List[SparkSession => Rule[SparkPlan]]
 
-  def genDecimalRoundTransformer(
-      substraitExprName: String,
-      child: ExpressionTransformer,
-      original: Round): ExpressionTransformer = {
-    GenericExpressionTransformer(substraitExprName, Seq(child), original)
-  }
-
   def genGetStructFieldTransformer(
       substraitExprName: String,
       childTransformer: ExpressionTransformer,
diff --git a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
index 80c8c6348ae7..7815cbf69ebd 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
@@ -243,11 +243,10 @@ object ExpressionConverter extends SQLConfHelper with Logging {
           d
         )
       case r: Round if r.child.dataType.isInstanceOf[DecimalType] =>
-        BackendsApiManager.getSparkPlanExecApiInstance.genDecimalRoundTransformer(
+        DecimalRoundTransformer(
           substraitExprName,
           replaceWithExpressionTransformerInternal(r.child, attributeSeq, expressionsMap),
-          r
-        )
+          r)
       case t: ToUnixTimestamp =>
         BackendsApiManager.getSparkPlanExecApiInstance.genUnixTimestampTransformer(
           substraitExprName,
diff --git a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
index cf6588c97221..bc0410834dd9 100644
--- a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
+++ b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
@@ -849,7 +849,6 @@ class ClickHouseTestSettings extends BackendTestSettings {
     .exclude("atan2")
     .exclude("round/bround")
     .exclude("SPARK-37388: width_bucket")
-    .excludeGlutenTest("round/bround")
   enableSuite[GlutenMiscExpressionsSuite]
   enableSuite[GlutenNondeterministicSuite]
     .exclude("MonotonicallyIncreasingID")
diff --git a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenMathExpressionsSuite.scala b/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenMathExpressionsSuite.scala
index 1a46caa25443..54583547d057 100644
--- a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenMathExpressionsSuite.scala
+++ b/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenMathExpressionsSuite.scala
@@ -16,6 +16,8 @@
  */
 package org.apache.spark.sql.catalyst.expressions
 
+import org.apache.gluten.utils.BackendTestUtils
+
 import org.apache.spark.sql.GlutenTestsTrait
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.types._
@@ -70,7 +72,13 @@ class GlutenMathExpressionsSuite extends MathExpressionsSuite with GlutenTestsTr
       checkEvaluation(BRound(shortPi, scale), shortResults(i), EmptyRow)
       checkEvaluation(BRound(intPi, scale), intResultsB(i), EmptyRow)
       checkEvaluation(BRound(longPi, scale), longResults(i), EmptyRow)
-      checkEvaluation(BRound(floatPi, scale), bRoundFloatResults(i), EmptyRow)
+      checkEvaluation(
+        BRound(floatPi, scale),
+        // the Velox backend falls back when executing bround,
+        // so it uses the same expected results as vanilla Spark
+        if (BackendTestUtils.isCHBackendLoaded()) floatResults(i) else bRoundFloatResults(i),
+        EmptyRow
+      )
     }
 
     val bdResults: Seq[BigDecimal] = Seq(
diff --git a/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
index 92032bd4385d..6a403204fb7a 100644
--- a/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
+++ b/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
@@ -853,7 +853,6 @@ class ClickHouseTestSettings extends BackendTestSettings {
     .exclude("unhex")
     .exclude("atan2")
     .exclude("round/bround/floor/ceil")
-    .excludeGlutenTest("round/bround/floor/ceil")
     .exclude("SPARK-36922: Support ANSI intervals for SIGN/SIGNUM")
     .exclude("SPARK-35926: Support YearMonthIntervalType in width-bucket function")
     .exclude("SPARK-35925: Support DayTimeIntervalType in width-bucket function")
diff --git a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenMathExpressionsSuite.scala b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenMathExpressionsSuite.scala
index e220924880c7..a60f0dce644b 100644
--- a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenMathExpressionsSuite.scala
+++ b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenMathExpressionsSuite.scala
@@ -16,6 +16,8 @@
  */
 package org.apache.spark.sql.catalyst.expressions
 
+import org.apache.gluten.utils.BackendTestUtils
+
 import org.apache.spark.sql.GlutenTestsTrait
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.types._
@@ -125,7 +127,13 @@ class GlutenMathExpressionsSuite extends
MathExpressionsSuite with GlutenTestsTr
       checkEvaluation(BRound(shortPi, scale), shortResults(i), EmptyRow)
       checkEvaluation(BRound(intPi, scale), intResultsB(i), EmptyRow)
       checkEvaluation(BRound(longPi, scale), longResults(i), EmptyRow)
-      checkEvaluation(BRound(floatPi, scale), bRoundFloatResults(i), EmptyRow)
+      checkEvaluation(
+        BRound(floatPi, scale),
+        // the Velox backend falls back when executing bround,
+        // so it uses the same expected results as vanilla Spark
+        if (BackendTestUtils.isCHBackendLoaded()) floatResults(i) else bRoundFloatResults(i),
+        EmptyRow
+      )
       checkEvaluation(
         checkDataTypeAndCast(RoundFloor(Literal(doublePi), Literal(scale))),
         doubleResultsFloor(i),
diff --git a/shims/spark33/src/main/scala/org/apache/gluten/sql/shims/spark33/Spark33Shims.scala b/shims/spark33/src/main/scala/org/apache/gluten/sql/shims/spark33/Spark33Shims.scala
index f6034c23bceb..d264bd1acc55 100644
--- a/shims/spark33/src/main/scala/org/apache/gluten/sql/shims/spark33/Spark33Shims.scala
+++ b/shims/spark33/src/main/scala/org/apache/gluten/sql/shims/spark33/Spark33Shims.scala
@@ -18,7 +18,7 @@ package org.apache.gluten.sql.shims.spark33
 
 import org.apache.gluten.execution.datasource.GlutenParquetWriterInjects
 import org.apache.gluten.expression.{ExpressionNames, Sig}
-import org.apache.gluten.expression.ExpressionNames.{KNOWN_NULLABLE, TIMESTAMP_ADD}
+import org.apache.gluten.expression.ExpressionNames.{CEIL, FLOOR, KNOWN_NULLABLE, TIMESTAMP_ADD}
 import org.apache.gluten.sql.shims.{ShimDescriptor, SparkShims}
 
 import org.apache.spark._
@@ -67,7 +67,9 @@ class Spark33Shims extends SparkShims {
       Sig[Csc](ExpressionNames.CSC),
       Sig[KnownNullable](KNOWN_NULLABLE),
       Sig[Empty2Null](ExpressionNames.EMPTY2NULL),
-      Sig[TimestampAdd](TIMESTAMP_ADD)
+      Sig[TimestampAdd](TIMESTAMP_ADD),
+      Sig[RoundFloor](FLOOR),
+      Sig[RoundCeil](CEIL)
     )
   }
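
-- 
Notes (illustrative sketches referenced from the commit message; not part of
the applied change):

The shape of the bug is easiest to see on a toy expression tree. The sketch
below is self-contained Scala with made-up types, not Gluten's actual
ExpressionTransformer API: a conversion that keeps round's `child` but drops
its `scale` behaves as if the user had written round(x, 0).

    object RoundScaleDemo extends App {
      // Toy expression tree (hypothetical; NOT Gluten's API).
      sealed trait Expr
      case class DecimalLit(value: BigDecimal) extends Expr
      case class IntLit(value: Int) extends Expr
      case class Round(child: Expr, scale: Expr) extends Expr

      def eval(e: Expr): BigDecimal = e match {
        case DecimalLit(v) => v
        case IntLit(v) => BigDecimal(v)
        case Round(c, s) =>
          // Spark's round() breaks ties half-up (away from zero).
          eval(c).setScale(eval(s).toIntExact, BigDecimal.RoundingMode.HALF_UP)
      }

      // A conversion that keeps only `child`, analogous to the bug fixed here:
      // the scale is dropped, so the expression degenerates to round(x, 0).
      def buggyConvert(r: Round): Round = Round(r.child, IntLit(0))

      val r = Round(DecimalLit(BigDecimal("2.345")), IntLit(2))
      println(eval(r))               // 2.35 (scale preserved)
      println(eval(buggyConvert(r))) // 2    (scale silently lost)
    }

As the ExpressionConverter hunk shows, decimal rounds now always build
DecimalRoundTransformer from the original Round expression (which still
carries the scale) instead of going through a backend-specific override.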
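The C++ hunks replace the scalar fallback with SSE specializations while
still calling std::roundf / std::round per lane. That choice matters because
Spark's round() breaks ties half-up (away from zero), whereas the default
IEEE nearest-int rounding, which a vectorized rounding instruction would
give, breaks ties to even. A plain-BigDecimal Scala sketch of the two
tie-breaking rules (it only mirrors the semantics; it is not the kernel):

    object TieBreakingDemo extends App {
      import scala.math.BigDecimal.RoundingMode

      // HALF_UP: ties go away from zero (Spark round()).
      // HALF_EVEN: ties go to the even neighbor (bround(), IEEE default).
      for (s <- Seq("0.5", "1.5", "2.5", "-0.5", "-2.5")) {
        val v = BigDecimal(s)
        val up = v.setScale(0, RoundingMode.HALF_UP)
        val even = v.setScale(0, RoundingMode.HALF_EVEN)
        println(f"$s%6s  HALF_UP=$up%-3s  HALF_EVEN=$even%s")
      }
    }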
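At the SQL level, the behaviors pinned down by the new tests can be
reproduced in any Spark session. The snippet assumes a spark-shell with a
local SparkSession named `spark`; the expected values are ordinary Spark
semantics, not backend-specific:

    import org.apache.spark.sql.functions.{bround, lit, round}

    // round() must ship BOTH the value and the scale to the backend,
    // which is exactly what the decimal transformer was dropping.
    spark.range(1).select(
      round(lit(BigDecimal("2.345")), 2), // 2.35 (decimal input, scale = 2)
      round(lit(2.5), 0),                 // 3.0  (ties half-up)
      bround(lit(2.5), 0)                 // 2.0  (ties half-even)
    ).show()

The Spark33Shims hunk additionally maps RoundFloor and RoundCeil to the
FLOOR and CEIL expression names, so Spark 3.3's floor/ceil with a scale
argument resolve to supported function names rather than falling back.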