From a1f3198195454ff364b8f1f4b00b9d04a32355df Mon Sep 17 00:00:00 2001 From: Arnav Balyan <60175178+ArnavBalyan@users.noreply.github.com> Date: Fri, 9 Aug 2024 18:57:11 +0530 Subject: [PATCH] [VL] Fix high precision rounding (#6707) --- cpp/velox/operators/functions/Arithmetic.h | 11 +++++++---- .../expressions/GlutenMathExpressionsSuite.scala | 3 +++ .../expressions/GlutenMathExpressionsSuite.scala | 4 ++++ .../expressions/GlutenMathExpressionsSuite.scala | 3 +++ 4 files changed, 17 insertions(+), 4 deletions(-) diff --git a/cpp/velox/operators/functions/Arithmetic.h b/cpp/velox/operators/functions/Arithmetic.h index 0474e1554981..7b4c9ae9db7c 100644 --- a/cpp/velox/operators/functions/Arithmetic.h +++ b/cpp/velox/operators/functions/Arithmetic.h @@ -17,6 +17,7 @@ #include #include #include +#include #include namespace gluten { @@ -38,14 +39,16 @@ struct RoundFunction { return number; } - double factor = std::pow(10, decimals); + // Using long double for high precision during intermediate calculations. + // TODO: Make this more efficient with Boost to support high arbitrary precision at runtime. + long double factor = std::pow(10.0L, static_cast(decimals)); static const TNum kInf = std::numeric_limits::infinity(); + if (number < 0) { - return (std::round(std::nextafter(number, -kInf) * factor * -1) / factor) * -1; + return static_cast((std::round(std::nextafter(number, -kInf) * factor * -1) / factor) * -1); } - return std::round(std::nextafter(number, kInf) * factor) / factor; + return static_cast(std::round(std::nextafter(number, kInf) * factor) / factor); } - template FOLLY_ALWAYS_INLINE void call(TInput& result, const TInput& a, const int32_t b = 0) { result = round(a, b); diff --git a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenMathExpressionsSuite.scala b/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenMathExpressionsSuite.scala index 54583547d057..765a64f91baf 100644 --- a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenMathExpressionsSuite.scala +++ b/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenMathExpressionsSuite.scala @@ -121,6 +121,9 @@ class GlutenMathExpressionsSuite extends MathExpressionsSuite with GlutenTestsTr checkEvaluation(Round(-3.5, 0), -4.0) checkEvaluation(Round(-0.35, 1), -0.4) checkEvaluation(Round(-35, -1), -40) + checkEvaluation(Round(1.12345678901234567, 8), 1.12345679) + checkEvaluation(Round(-0.98765432109876543, 5), -0.98765) + checkEvaluation(Round(12345.67890123456789, 6), 12345.678901) checkEvaluation(BRound(2.5, 0), 2.0) checkEvaluation(BRound(3.5, 0), 4.0) checkEvaluation(BRound(-2.5, 0), -2.0) diff --git a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenMathExpressionsSuite.scala b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenMathExpressionsSuite.scala index a60f0dce644b..122f8dc066af 100644 --- a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenMathExpressionsSuite.scala +++ b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenMathExpressionsSuite.scala @@ -249,6 +249,10 @@ class GlutenMathExpressionsSuite extends MathExpressionsSuite with GlutenTestsTr checkEvaluation(Round(-3.5, 0), -4.0) checkEvaluation(Round(-0.35, 1), -0.4) checkEvaluation(Round(-35, -1), -40) + checkEvaluation(Round(1.12345678901234567, 8), 1.12345679) + checkEvaluation(Round(-0.98765432109876543, 5), -0.98765) + checkEvaluation(Round(12345.67890123456789, 6), 12345.678901) + checkEvaluation(Round(-35, -1), -40) checkEvaluation(Round(BigDecimal("45.00"), -1), BigDecimal(50)) checkEvaluation(BRound(2.5, 0), 2.0) checkEvaluation(BRound(3.5, 0), 4.0) diff --git a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenMathExpressionsSuite.scala b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenMathExpressionsSuite.scala index e220924880c7..7308352e40c6 100644 --- a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenMathExpressionsSuite.scala +++ b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/catalyst/expressions/GlutenMathExpressionsSuite.scala @@ -248,6 +248,9 @@ class GlutenMathExpressionsSuite extends MathExpressionsSuite with GlutenTestsTr checkEvaluation(BRound(-3.5, 0), -4.0) checkEvaluation(BRound(-0.35, 1), -0.4) checkEvaluation(BRound(-35, -1), -40) + checkEvaluation(Round(1.12345678901234567, 8), 1.12345679) + checkEvaluation(Round(-0.98765432109876543, 5), -0.98765) + checkEvaluation(Round(12345.67890123456789, 6), 12345.678901) checkEvaluation(BRound(BigDecimal("45.00"), -1), BigDecimal(40)) checkEvaluation(checkDataTypeAndCast(RoundFloor(Literal(2.5), Literal(0))), Decimal(2)) checkEvaluation(checkDataTypeAndCast(RoundFloor(Literal(3.5), Literal(0))), Decimal(3))