[GLUTEN-5512][CH] Fix the incorrect transformer for the round function with the decimal data type (#5513)

[CH] Fix the incorrect transformer for the round function with the decimal data type.

When transforming the round function with the decimal data type, only the `child` expression was transformed and the `scale` expression was missed, which led to incorrect results when executing `round(decimal_data, x)`.

Closes #5512.
zzcclp authored Apr 24, 2024
1 parent a683e31 commit 89015c9
Showing 11 changed files with 151 additions and 58 deletions.
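For context before the diffs: the pre-fix default transformer (removed from `SparkPlanExecApi` below) emitted only the `child` as a Substrait argument, so the `scale` literal of `round` never reached the CH backend. A minimal reproducer sketch, assuming a Spark session running with the Gluten CH backend enabled (the object name and session setup are hypothetical):

import org.apache.spark.sql.SparkSession

// Hypothetical reproducer for #5512. Pre-fix, round() on a decimal column was
// transformed with only its child expression, silently dropping the scale.
object Round5512Repro {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("round-5512").getOrCreate()
    // Vanilla Spark returns 3.14; with the scale dropped, the CH backend's
    // result no longer matched.
    spark.sql("select round(cast(3.14159 as decimal(10, 5)), 2)").show()
    spark.stop()
  }
}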
@@ -17,7 +17,7 @@
package org.apache.gluten.execution

import org.apache.spark.SparkConf
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Row, TestUtils}
import org.apache.spark.sql.execution.InputIteratorTransformer
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.aggregate.SortAggregateExec
@@ -582,11 +582,18 @@ class GlutenClickHouseTPCHBucketSuite
assert(plans.size == expectedCount)
}

-def checkResult(df: DataFrame, exceptedResult: Array[Row]): Unit = {
+def checkResult(df: DataFrame, exceptedResult: Seq[Row]): Unit = {
// check the result
val result = df.collect()
assert(result.size == exceptedResult.size)
-result.equals(exceptedResult)
+val sortedRes = result.map {
+  s =>
+    Row.fromSeq(s.toSeq.map {
+      case a: mutable.WrappedArray[_] => a.sortBy(_.toString.toInt)
+      case o => o
+    })
+}
+TestUtils.compareAnswers(sortedRes, exceptedResult)
}

val SQL =
@@ -600,10 +607,10 @@ class GlutenClickHouseTPCHBucketSuite
checkResult(
df,
Array(
Row(1, "N", mutable.WrappedArray.make(Array(3, 6, 1, 5, 2, 4))),
Row(1, "N", mutable.WrappedArray.make(Array(1, 2, 3, 4, 5, 6))),
Row(2, "N", mutable.WrappedArray.make(Array(1))),
Row(3, "A", mutable.WrappedArray.make(Array(6, 4, 3))),
Row(3, "R", mutable.WrappedArray.make(Array(2, 5, 1))),
Row(3, "A", mutable.WrappedArray.make(Array(3, 4, 6))),
Row(3, "R", mutable.WrappedArray.make(Array(1, 2, 5))),
Row(4, "N", mutable.WrappedArray.make(Array(1)))
)
)
@@ -645,11 +652,11 @@ class GlutenClickHouseTPCHBucketSuite
checkResult(
df,
Array(
Row("A", 3, mutable.WrappedArray.make(Array(6, 4, 3))),
Row("A", 3, mutable.WrappedArray.make(Array(3, 4, 6))),
Row("A", 5, mutable.WrappedArray.make(Array(3))),
Row("A", 6, mutable.WrappedArray.make(Array(1))),
Row("A", 33, mutable.WrappedArray.make(Array(1, 2, 3))),
Row("A", 37, mutable.WrappedArray.make(Array(2, 3, 1)))
Row("A", 37, mutable.WrappedArray.make(Array(1, 2, 3)))
)
)
checkHashAggregateCount(df, 1)
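A note on the test change above: the collected arrays come from an aggregation whose element order is not guaranteed, so `checkResult` now sorts array-typed cells on both sides before comparing, and the expected arrays in the hunks above are written pre-sorted. The same normalization in isolation, on hypothetical data:

import scala.collection.mutable

import org.apache.spark.sql.Row

// Sort array-typed cells so two row sets compare equal regardless of the
// order in which the aggregation collected the elements.
val rows = Seq(Row(1, "N", mutable.WrappedArray.make(Array(3, 6, 1, 5, 2, 4))))
val normalized = rows.map { r =>
  Row.fromSeq(r.toSeq.map {
    case a: mutable.WrappedArray[_] => a.sortBy(_.toString.toInt)
    case other => other
  })
}
// normalized == Seq(Row(1, "N", WrappedArray(1, 2, 3, 4, 5, 6)))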
@@ -20,7 +20,7 @@ import org.apache.gluten.GlutenConfig
import org.apache.gluten.utils.UTSystemParameters

import org.apache.spark.SparkConf
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Row, TestUtils}
import org.apache.spark.sql.catalyst.optimizer.{ConstantFolding, NullPropagation}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -495,13 +495,54 @@ class GlutenFunctionValidateSuite extends GlutenClickHouseWholeStageTransformerS
}

test("test round issue: https://github.com/oap-project/gluten/issues/3462") {
-runQueryAndCompare(
-  "select round(0.41875d * id , 4) from range(10);"
-)(checkGlutenOperatorMatch[ProjectExecTransformer])
+def checkResult(df: DataFrame, exceptedResult: Seq[Row]): Unit = {
+  // check the result
+  val result = df.collect()
+  assert(result.size == exceptedResult.size)
+  TestUtils.compareAnswers(result, exceptedResult)
+}

-runQueryAndCompare(
-  "select round(0.41875f * id , 4) from range(10);"
-)(checkGlutenOperatorMatch[ProjectExecTransformer])
+runSql("select round(0.41875d * id , 4) from range(10);")(
+  df => {
+    checkGlutenOperatorMatch[ProjectExecTransformer](df)
+
+    checkResult(
+      df,
+      Seq(
+        Row(0.0),
+        Row(0.4188),
+        Row(0.8375),
+        Row(1.2563),
+        Row(1.675),
+        Row(2.0938),
+        Row(2.5125),
+        Row(2.9313),
+        Row(3.35),
+        Row(3.7688)
+      )
+    )
+  })
+
+runSql("select round(0.41875f * id , 4) from range(10);")(
+  df => {
+    checkGlutenOperatorMatch[ProjectExecTransformer](df)
+
+    checkResult(
+      df,
+      Seq(
+        Row(0.0f),
+        Row(0.4188f),
+        Row(0.8375f),
+        Row(1.2562f),
+        Row(1.675f),
+        Row(2.0938f),
+        Row(2.5125f),
+        Row(2.9312f),
+        Row(3.35f),
+        Row(3.7688f)
+      )
+    )
+  })
}

test("test date comparision expression override") {
@@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.{AggregateFunctionRewriteRule, FlushableHas
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
-import org.apache.spark.sql.catalyst.expressions.{Add, Alias, ArrayExists, ArrayFilter, ArrayForAll, ArrayTransform, Ascending, Attribute, Cast, CreateNamedStruct, ElementAt, Expression, ExpressionInfo, Generator, GetArrayItem, GetMapValue, GetStructField, If, IsNaN, LambdaFunction, Literal, Murmur3Hash, NamedExpression, NaNvl, PosExplode, Round, SortOrder, StringSplit, StringTrim, TryEval, Uuid}
+import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, HLLAdapter}
import org.apache.spark.sql.catalyst.optimizer.BuildSide
import org.apache.spark.sql.catalyst.plans.JoinType
@@ -572,14 +572,6 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
* * Expressions.
*/

-/** Generates a transformer for decimal round. */
-override def genDecimalRoundTransformer(
-    substraitExprName: String,
-    child: ExpressionTransformer,
-    original: Round): ExpressionTransformer = {
-  DecimalRoundTransformer(substraitExprName, child, original)
-}

/** Generate StringSplit transformer. */
override def genStringSplitTransformer(
substraitExprName: String,
81 changes: 63 additions & 18 deletions cpp-ch/local-engine/Functions/SparkFunctionRoundHalfUp.h
@@ -18,32 +18,77 @@

#include <Functions/FunctionsRound.h>


namespace local_engine
{
using namespace DB;

template <typename T>
-class BaseFloatRoundingHalfUpComputation
+class BaseFloatRoundingHalfUpComputation;
+
+template <>
+class BaseFloatRoundingHalfUpComputation<Float32>
{
public:
+    using ScalarType = Float32;
+    using VectorType = __m128;
+    static const size_t data_count = 4;
+
+    static VectorType load(const ScalarType * in) { return _mm_loadu_ps(in); }
+    static VectorType load1(const ScalarType in) { return _mm_load1_ps(&in); }
+    static void store(ScalarType * out, VectorType val) { _mm_storeu_ps(out, val); }
+    static VectorType multiply(VectorType val, VectorType scale) { return _mm_mul_ps(val, scale); }
+    static VectorType divide(VectorType val, VectorType scale) { return _mm_div_ps(val, scale); }
+    template <RoundingMode mode> static VectorType apply(VectorType val)
+    {
+        ScalarType tempFloatsIn[data_count];
+        ScalarType tempFloatsOut[data_count];
+        store(tempFloatsIn, val);
+        for (size_t i = 0; i < data_count; ++i)
+            tempFloatsOut[i] = std::roundf(tempFloatsIn[i]);
+
+        return load(tempFloatsOut);
+    }
+
+    static VectorType prepare(size_t scale)
+    {
+        return load1(scale);
+    }
+};
+
+template <>
+class BaseFloatRoundingHalfUpComputation<Float64>
+{
+public:
-    using ScalarType = T;
-    using VectorType = Float64;
-    static const size_t data_count = 1;
-
-    static VectorType load(const ScalarType * in) { return static_cast<VectorType>(*in); }
-    static VectorType load1(ScalarType in) { return in; }
-    static ScalarType store(ScalarType * out, VectorType val) { return *out = static_cast<ScalarType>(val); }
-    static VectorType multiply(VectorType val, VectorType scale) { return val * scale; }
-    static VectorType divide(VectorType val, VectorType scale) { return val / scale; }
-    static VectorType apply(VectorType val) { return round(val); }
-    static VectorType prepare(size_t scale) { return load1(scale); }
+    using ScalarType = Float64;
+    using VectorType = __m128d;
+    static const size_t data_count = 2;
+
+    static VectorType load(const ScalarType * in) { return _mm_loadu_pd(in); }
+    static VectorType load1(const ScalarType in) { return _mm_load1_pd(&in); }
+    static void store(ScalarType * out, VectorType val) { _mm_storeu_pd(out, val); }
+    static VectorType multiply(VectorType val, VectorType scale) { return _mm_mul_pd(val, scale); }
+    static VectorType divide(VectorType val, VectorType scale) { return _mm_div_pd(val, scale); }
+    template <RoundingMode mode> static VectorType apply(VectorType val)
+    {
+        ScalarType tempFloatsIn[data_count];
+        ScalarType tempFloatsOut[data_count];
+        store(tempFloatsIn, val);
+        for (size_t i = 0; i < data_count; ++i)
+            tempFloatsOut[i] = std::round(tempFloatsIn[i]);
+
+        return load(tempFloatsOut);
+    }
+
+    static VectorType prepare(size_t scale)
+    {
+        return load1(scale);
+    }
};


/** Implementation of low-level round-off functions for floating-point values.
*/
-template <typename T, ScaleMode scale_mode>
+template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode>
class FloatRoundingHalfUpComputation : public BaseFloatRoundingHalfUpComputation<T>
{
using Base = BaseFloatRoundingHalfUpComputation<T>;
@@ -58,7 +103,7 @@ class FloatRoundingHalfUpComputation : public BaseFloatRoundingHalfUpComputation
else if (scale_mode == ScaleMode::Negative)
val = Base::divide(val, scale);

-val = Base::apply(val);
+val = Base::template apply<rounding_mode>(val);

if (scale_mode == ScaleMode::Positive)
val = Base::divide(val, scale);
@@ -72,13 +117,13 @@

/** Implementing high-level rounding functions.
*/
-template <typename T, ScaleMode scale_mode>
+template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode>
struct FloatRoundingHalfUpImpl
{
private:
static_assert(!is_decimal<T>);

-using Op = FloatRoundingHalfUpComputation<T, scale_mode>;
+using Op = FloatRoundingHalfUpComputation<T, rounding_mode, scale_mode>;
using Data = std::array<T, Op::data_count>;
using ColumnType = ColumnVector<T>;
using Container = typename ColumnType::Container;
@@ -125,7 +170,7 @@ struct DispatcherRoundingHalfUp
{
template <ScaleMode scale_mode>
using FunctionRoundingImpl = std::conditional_t<std::is_floating_point_v<T>,
-FloatRoundingHalfUpImpl<T, scale_mode>,
+FloatRoundingHalfUpImpl<T, rounding_mode, scale_mode>,
IntegerRoundingImpl<T, rounding_mode, scale_mode, tie_breaking_mode>>;

static ColumnPtr apply(const IColumn * col_general, Scale scale_arg)
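Why this header exists at all: Spark's `round` rounds half up, while ClickHouse's built-in float rounding rounds half to even (banker's rounding), so local_engine carries its own half-up implementation; the new `rounding_mode` template parameter appears intended to let the same SIMD path also serve the `RoundFloor`/`RoundCeil` modes registered in the Spark 3.3 shims below. The tie-breaking difference in a plain JVM check (no Spark required):

import scala.math.BigDecimal.RoundingMode

// On a tie (.5), half-up and half-even disagree:
val x = BigDecimal("2.5")
println(x.setScale(0, RoundingMode.HALF_UP))   // 3
println(x.setScale(0, RoundingMode.HALF_EVEN)) // 2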
@@ -415,13 +415,6 @@
*/
def genExtendedColumnarPostRules(): List[SparkSession => Rule[SparkPlan]]

-def genDecimalRoundTransformer(
-    substraitExprName: String,
-    child: ExpressionTransformer,
-    original: Round): ExpressionTransformer = {
-  GenericExpressionTransformer(substraitExprName, Seq(child), original)
-}

def genGetStructFieldTransformer(
substraitExprName: String,
childTransformer: ExpressionTransformer,
@@ -243,11 +243,10 @@ object ExpressionConverter extends SQLConfHelper with Logging {
d
)
case r: Round if r.child.dataType.isInstanceOf[DecimalType] =>
-BackendsApiManager.getSparkPlanExecApiInstance.genDecimalRoundTransformer(
+DecimalRoundTransformer(
substraitExprName,
replaceWithExpressionTransformerInternal(r.child, attributeSeq, expressionsMap),
-r
-)
+r)
case t: ToUnixTimestamp =>
BackendsApiManager.getSparkPlanExecApiInstance.genUnixTimestampTransformer(
substraitExprName,
@@ -849,7 +849,6 @@ class ClickHouseTestSettings extends BackendTestSettings {
.exclude("atan2")
.exclude("round/bround")
.exclude("SPARK-37388: width_bucket")
.excludeGlutenTest("round/bround")
enableSuite[GlutenMiscExpressionsSuite]
enableSuite[GlutenNondeterministicSuite]
.exclude("MonotonicallyIncreasingID")
@@ -16,6 +16,8 @@
*/
package org.apache.spark.sql.catalyst.expressions

+import org.apache.gluten.utils.BackendTestUtils

import org.apache.spark.sql.GlutenTestsTrait
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.types._
@@ -70,7 +72,13 @@ class GlutenMathExpressionsSuite extends MathExpressionsSuite with GlutenTestsTr
checkEvaluation(BRound(shortPi, scale), shortResults(i), EmptyRow)
checkEvaluation(BRound(intPi, scale), intResultsB(i), EmptyRow)
checkEvaluation(BRound(longPi, scale), longResults(i), EmptyRow)
-checkEvaluation(BRound(floatPi, scale), bRoundFloatResults(i), EmptyRow)
+checkEvaluation(
+  BRound(floatPi, scale),
+  // the velox backend falls back when executing bround,
+  // so it uses the same expected results as vanilla spark
+  if (BackendTestUtils.isCHBackendLoaded()) floatResults(i) else bRoundFloatResults(i),
+  EmptyRow
+)
}

val bdResults: Seq[BigDecimal] = Seq(
@@ -853,7 +853,6 @@ class ClickHouseTestSettings extends BackendTestSettings {
.exclude("unhex")
.exclude("atan2")
.exclude("round/bround/floor/ceil")
.excludeGlutenTest("round/bround/floor/ceil")
.exclude("SPARK-36922: Support ANSI intervals for SIGN/SIGNUM")
.exclude("SPARK-35926: Support YearMonthIntervalType in width-bucket function")
.exclude("SPARK-35925: Support DayTimeIntervalType in width-bucket function")
@@ -16,6 +16,8 @@
*/
package org.apache.spark.sql.catalyst.expressions

+import org.apache.gluten.utils.BackendTestUtils

import org.apache.spark.sql.GlutenTestsTrait
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.types._
@@ -125,7 +127,13 @@ class GlutenMathExpressionsSuite extends MathExpressionsSuite with GlutenTestsTr
checkEvaluation(BRound(shortPi, scale), shortResults(i), EmptyRow)
checkEvaluation(BRound(intPi, scale), intResultsB(i), EmptyRow)
checkEvaluation(BRound(longPi, scale), longResults(i), EmptyRow)
-checkEvaluation(BRound(floatPi, scale), bRoundFloatResults(i), EmptyRow)
+checkEvaluation(
+  BRound(floatPi, scale),
+  // the velox backend falls back when executing bround,
+  // so it uses the same expected results as vanilla spark
+  if (BackendTestUtils.isCHBackendLoaded()) floatResults(i) else bRoundFloatResults(i),
+  EmptyRow
+)
checkEvaluation(
checkDataTypeAndCast(RoundFloor(Literal(doublePi), Literal(scale))),
doubleResultsFloor(i),
@@ -18,7 +18,7 @@ package org.apache.gluten.sql.shims.spark33

import org.apache.gluten.execution.datasource.GlutenParquetWriterInjects
import org.apache.gluten.expression.{ExpressionNames, Sig}
-import org.apache.gluten.expression.ExpressionNames.{KNOWN_NULLABLE, TIMESTAMP_ADD}
+import org.apache.gluten.expression.ExpressionNames.{CEIL, FLOOR, KNOWN_NULLABLE, TIMESTAMP_ADD}
import org.apache.gluten.sql.shims.{ShimDescriptor, SparkShims}

import org.apache.spark._
@@ -67,7 +67,9 @@ class Spark33Shims extends SparkShims {
Sig[Csc](ExpressionNames.CSC),
Sig[KnownNullable](KNOWN_NULLABLE),
Sig[Empty2Null](ExpressionNames.EMPTY2NULL),
-Sig[TimestampAdd](TIMESTAMP_ADD)
+Sig[TimestampAdd](TIMESTAMP_ADD),
+Sig[RoundFloor](FLOOR),
+Sig[RoundCeil](CEIL)
)
}

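The two new `Sig` entries map Spark 3.3's scale-aware `RoundFloor` and `RoundCeil` expressions, which back the two-argument `floor`/`ceil` SQL functions added in Spark 3.3, onto the existing Substrait FLOOR/CEIL names so they can be offloaded instead of falling back. A small usage sketch, assuming an active `spark` session:

// Spark 3.3+ accepts a scale argument for floor/ceil; these parse to the
// RoundFloor / RoundCeil expressions registered above.
spark.sql("SELECT floor(3.1411, 3) AS f, ceil(3.1411, 3) AS c").show()
// f = 3.141, c = 3.142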
