From 8044475dd0f18e350c05aa7614c4e2fe9fbe1bf9 Mon Sep 17 00:00:00 2001 From: Yuan Zhou Date: Fri, 25 Aug 2023 09:05:40 +0800 Subject: [PATCH] adding param allowPrecisionLoss Signed-off-by: Yuan Zhou --- velox/core/QueryConfig.h | 7 +++ .../functions/sparksql/DecimalArithmetic.cpp | 46 +++++++++++++------ velox/functions/sparksql/DecimalUtil.h | 10 ++++ 3 files changed, 50 insertions(+), 13 deletions(-) diff --git a/velox/core/QueryConfig.h b/velox/core/QueryConfig.h index 7f2b110331d99..8db1aa8718d64 100644 --- a/velox/core/QueryConfig.h +++ b/velox/core/QueryConfig.h @@ -101,6 +101,9 @@ class QueryConfig { static constexpr const char* kCastMatchStructByName = "cast_match_struct_by_name"; + // This flags forces to bound the decimal precision. + static constexpr const char* kAllowPrecisionLoss = "allow_precision_loss"; + /// Used for backpressure to block local exchange producers when the local /// exchange buffer reaches or exceeds this size. static constexpr const char* kMaxLocalExchangeBufferSize = @@ -496,6 +499,10 @@ class QueryConfig { return get(kCastMatchStructByName, false); } + bool isAllowPrecisionLoss() const { + return get(kAllowPrecisionLoss, true); + } + bool codegenEnabled() const { return get(kCodegenEnabled, false); } diff --git a/velox/functions/sparksql/DecimalArithmetic.cpp b/velox/functions/sparksql/DecimalArithmetic.cpp index 519ff0d5cf588..ee1c17f84fccc 100644 --- a/velox/functions/sparksql/DecimalArithmetic.cpp +++ b/velox/functions/sparksql/DecimalArithmetic.cpp @@ -416,11 +416,14 @@ class Addition { uint8_t aPrecision, uint8_t aScale, uint8_t bPrecision, - uint8_t bScale) { + uint8_t bScale, + bool allowPrecisionLoss) { auto precision = std::max(aPrecision - aScale, bPrecision - bScale) + std::max(aScale, bScale) + 1; auto scale = std::max(aScale, bScale); - return DecimalUtil::adjustPrecisionScale(precision, scale); + return allowPrecisionLoss + ? DecimalUtil::adjustPrecisionScale(precision, scale) + : DecimalUtil::bounded(precision, scale); } }; @@ -464,9 +467,10 @@ class Subtraction { uint8_t aPrecision, uint8_t aScale, uint8_t bPrecision, - uint8_t bScale) { + uint8_t bScale, + bool allowPrecisionLoss) { return Addition::computeResultPrecisionScale( - aPrecision, aScale, bPrecision, bScale); + aPrecision, aScale, bPrecision, bScale, allowPrecisionLoss); } }; @@ -566,9 +570,12 @@ class Multiply { uint8_t aPrecision, uint8_t aScale, uint8_t bPrecision, - uint8_t bScale) { - return DecimalUtil::adjustPrecisionScale( - aPrecision + bPrecision + 1, aScale + bScale); + uint8_t bScale, + const bool allowPrecisionLoss) { + return allowPrecisionLoss + ? DecimalUtil::adjustPrecisionScale( + aPrecision + bPrecision + 1, aScale + bScale) + : DecimalUtil::bounded(aPrecision + bPrecision + 1, aScale + bScale); } private: @@ -616,10 +623,22 @@ class Divide { uint8_t aPrecision, uint8_t aScale, uint8_t bPrecision, - uint8_t bScale) { - auto scale = std::max(6, aScale + bPrecision + 1); - auto precision = aPrecision - aScale + bScale + scale; - return DecimalUtil::adjustPrecisionScale(precision, scale); + uint8_t bScale, + bool allowPrecisionLoss) { + if (allowPrecisionLoss) { + auto scale = std::max(6, aScale + bPrecision + 1); + auto precision = aPrecision - aScale + bScale + scale; + return DecimalUtil::adjustPrecisionScale(precision, scale); + } else { + auto intDig = std::min(38, aPrecision - aScale + bScale); + auto decDig = std::min(38, std::max(6, aScale + bPrecision + 1)); + auto diff = (intDig + decDig) - 38; + if (diff > 0) { + decDig -= diff / 2 + 1; + intDig = 38 - decDig; + } + return DecimalUtil::bounded(intDig + decDig, decDig); + } } }; @@ -689,13 +708,14 @@ template std::shared_ptr createDecimalFunction( const std::string& name, const std::vector& inputArgs, - const core::QueryConfig& /*config*/) { + const core::QueryConfig& config) { const auto& aType = inputArgs[0].type; const auto& bType = inputArgs[1].type; const auto [aPrecision, aScale] = getDecimalPrecisionScale(*aType); const auto [bPrecision, bScale] = getDecimalPrecisionScale(*bType); + const bool allowPrecisionLoss = config.isAllowPrecisionLoss(); const auto [rPrecision, rScale] = Operation::computeResultPrecisionScale( - aPrecision, aScale, bPrecision, bScale); + aPrecision, aScale, bPrecision, bScale, allowPrecisionLoss); const uint8_t aRescale = Operation::computeRescaleFactor(aScale, bScale, rScale); const uint8_t bRescale = diff --git a/velox/functions/sparksql/DecimalUtil.h b/velox/functions/sparksql/DecimalUtil.h index fbe5da77809ec..c8b58a5cfa000 100644 --- a/velox/functions/sparksql/DecimalUtil.h +++ b/velox/functions/sparksql/DecimalUtil.h @@ -46,6 +46,16 @@ class DecimalUtil { } } + /// This method is used when + /// `spark.sql.decimalOperations.allowPrecisionLoss` is set to false. + inline static std::pair bounded( + uint8_t rPrecision, + uint8_t rScale) { + return { + std::min(static_cast(rPrecision), 38), + std::min(static_cast(rScale), 38)}; + } + /// @brief Convert int256 value to int64 or int128, set overflow to true if /// value cannot convert to specific type. /// @return The converted value.