diff --git a/velox/core/QueryConfig.h b/velox/core/QueryConfig.h index 3696fcfc0fe7..f3cbe7e78d6e 100644 --- a/velox/core/QueryConfig.h +++ b/velox/core/QueryConfig.h @@ -80,6 +80,9 @@ class QueryConfig { // truncating the decimal part instead of rounding. static constexpr const char* kCastToIntByTruncate = "cast_to_int_by_truncate"; + // This flags forces to bound the decimal precision. + static constexpr const char* kAllowPrecisionLoss = "allow_precision_loss"; + /// Used for backpressure to block local exchange producers when the local /// exchange buffer reaches or exceeds this size. static constexpr const char* kMaxLocalExchangeBufferSize = @@ -329,6 +332,10 @@ class QueryConfig { return get(kCastToIntByTruncate, false); } + bool isAllowPrecisionLoss() const { + return get(kAllowPrecisionLoss, true); + } + bool codegenEnabled() const { return get(kCodegenEnabled, false); } diff --git a/velox/functions/sparksql/DecimalArithmetic.cpp b/velox/functions/sparksql/DecimalArithmetic.cpp index cb85c22e7361..afede565f97a 100644 --- a/velox/functions/sparksql/DecimalArithmetic.cpp +++ b/velox/functions/sparksql/DecimalArithmetic.cpp @@ -23,6 +23,12 @@ namespace facebook::velox::functions::sparksql { namespace { +inline static std::pair bounded( + const uint8_t rPrecision, + const uint8_t rScale) { + return {std::min(rPrecision, 38), std::min(rScale, 38)}; +} + inline static std::pair adjustPrecisionScale( const uint8_t rPrecision, const uint8_t rScale) { @@ -385,11 +391,13 @@ class Addition { const uint8_t aPrecision, const uint8_t aScale, const uint8_t bPrecision, - const uint8_t bScale) { + const uint8_t bScale, + const allowPrecisionLoss) { auto precision = std::max(aPrecision - aScale, bPrecision - bScale) + std::max(aScale, bScale) + 1; auto scale = std::max(aScale, bScale); - return adjustPrecisionScale(precision, scale); + return allowPrecisionLoss ? adjustPrecisionScale(precision, scale) + : bounded(precision, scale); } }; @@ -433,7 +441,8 @@ class Subtraction { const uint8_t aPrecision, const uint8_t aScale, const uint8_t bPrecision, - const uint8_t bScale) { + const uint8_t bScale, + const bool allowPrecisionLoss) { return Addition::computeResultPrecisionScale( aPrecision, aScale, bPrecision, bScale); } @@ -539,8 +548,11 @@ class Multiply { const uint8_t aPrecision, const uint8_t aScale, const uint8_t bPrecision, - const uint8_t bScale) { - return adjustPrecisionScale(aPrecision + bPrecision + 1, aScale + bScale); + const uint8_t bScale, + const bool allowPrecisionLoss) { + return allowPrecisionLoss + ? adjustPrecisionScale(aPrecision + bPrecision + 1, aScale + bScale) + : bounded(aPrecision + bPrecision + 1, aScale + bScale) } private: @@ -591,10 +603,22 @@ class Divide { const uint8_t aPrecision, const uint8_t aScale, const uint8_t bPrecision, - const uint8_t bScale) { - auto scale = std::max(6, aScale + bPrecision + 1); - auto precision = aPrecision - aScale + bScale + scale; - return adjustPrecisionScale(precision, scale); + const uint8_t bScale, + const bool allowPrecisionLoss) { + if (allowPrecisionLoss) { + auto scale = std::max(6, aScale + bPrecision + 1); + auto precision = aPrecision - aScale + bScale + scale; + return adjustPrecisionScale(precision, scale); + } else { + auto intDig = std::min(38, aPrecision - aScale + bScale); + auto decDig = + std::min(38, std::max(6, aScale + bPrecision + 1)) auto diff = + (intDig + decDig) - 38; + if (diff > 0) { + decDig -= diff / 2 + 1 intDig = 38 - decDig + } + return bounded(intDig + decDig, decDig); + } } }; @@ -664,13 +688,14 @@ template std::shared_ptr createDecimalFunction( const std::string& name, const std::vector& inputArgs, - const core::QueryConfig& /*config*/) { + const core::QueryConfig& config) { auto aType = inputArgs[0].type; auto bType = inputArgs[1].type; auto [aPrecision, aScale] = getDecimalPrecisionScale(*aType); auto [bPrecision, bScale] = getDecimalPrecisionScale(*bType); + const bool allowPrecisionLoss = config.isAllowPrecisionLoss(); auto [rPrecision, rScale] = Operation::computeResultPrecisionScale( - aPrecision, aScale, bPrecision, bScale); + aPrecision, aScale, bPrecision, bScale, allowPrecisionLoss); uint8_t aRescale = Operation::computeRescaleFactor(aScale, bScale, rScale); uint8_t bRescale = Operation::computeRescaleFactor(bScale, aScale, rScale); if (aType->isShortDecimal()) {