Skip to content

Commit

Permalink
adding param allowPrecisionLoss
Browse files Browse the repository at this point in the history
Signed-off-by: Yuan Zhou <[email protected]>
  • Loading branch information
zhouyuan committed Aug 25, 2023
1 parent 0034e47 commit 74a7a98
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 11 deletions.
7 changes: 7 additions & 0 deletions velox/core/QueryConfig.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ class QueryConfig {
// truncating the decimal part instead of rounding.
static constexpr const char* kCastToIntByTruncate = "cast_to_int_by_truncate";

// This flags forces to bound the decimal precision.
static constexpr const char* kAllowPrecisionLoss = "allow_precision_loss";

/// Used for backpressure to block local exchange producers when the local
/// exchange buffer reaches or exceeds this size.
static constexpr const char* kMaxLocalExchangeBufferSize =
Expand Down Expand Up @@ -329,6 +332,10 @@ class QueryConfig {
return get<bool>(kCastToIntByTruncate, false);
}

bool isAllowPrecisionLoss() const {
return get<bool>(kAllowPrecisionLoss, true);
}

bool codegenEnabled() const {
return get<bool>(kCodegenEnabled, false);
}
Expand Down
47 changes: 36 additions & 11 deletions velox/functions/sparksql/DecimalArithmetic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@
namespace facebook::velox::functions::sparksql {
namespace {

inline static std::pair<uint8_t, uint8_t> bounded(
const uint8_t rPrecision,
const uint8_t rScale) {
return {std::min(rPrecision, 38), std::min(rScale, 38)};
}

inline static std::pair<uint8_t, uint8_t> adjustPrecisionScale(
const uint8_t rPrecision,
const uint8_t rScale) {
Expand Down Expand Up @@ -385,11 +391,13 @@ class Addition {
const uint8_t aPrecision,
const uint8_t aScale,
const uint8_t bPrecision,
const uint8_t bScale) {
const uint8_t bScale,
const allowPrecisionLoss) {
auto precision = std::max(aPrecision - aScale, bPrecision - bScale) +
std::max(aScale, bScale) + 1;
auto scale = std::max(aScale, bScale);
return adjustPrecisionScale(precision, scale);
return allowPrecisionLoss ? adjustPrecisionScale(precision, scale)
: bounded(precision, scale);
}
};

Expand Down Expand Up @@ -433,7 +441,8 @@ class Subtraction {
const uint8_t aPrecision,
const uint8_t aScale,
const uint8_t bPrecision,
const uint8_t bScale) {
const uint8_t bScale,
const bool allowPrecisionLoss) {
return Addition::computeResultPrecisionScale(
aPrecision, aScale, bPrecision, bScale);
}
Expand Down Expand Up @@ -539,8 +548,11 @@ class Multiply {
const uint8_t aPrecision,
const uint8_t aScale,
const uint8_t bPrecision,
const uint8_t bScale) {
return adjustPrecisionScale(aPrecision + bPrecision + 1, aScale + bScale);
const uint8_t bScale,
const bool allowPrecisionLoss) {
return allowPrecisionLoss
? adjustPrecisionScale(aPrecision + bPrecision + 1, aScale + bScale)
: bounded(aPrecision + bPrecision + 1, aScale + bScale)
}

private:
Expand Down Expand Up @@ -591,10 +603,22 @@ class Divide {
const uint8_t aPrecision,
const uint8_t aScale,
const uint8_t bPrecision,
const uint8_t bScale) {
auto scale = std::max(6, aScale + bPrecision + 1);
auto precision = aPrecision - aScale + bScale + scale;
return adjustPrecisionScale(precision, scale);
const uint8_t bScale,
const bool allowPrecisionLoss) {
if (allowPrecisionLoss) {
auto scale = std::max(6, aScale + bPrecision + 1);
auto precision = aPrecision - aScale + bScale + scale;
return adjustPrecisionScale(precision, scale);
} else {
auto intDig = std::min(38, aPrecision - aScale + bScale);
auto decDig =
std::min(38, std::max(6, aScale + bPrecision + 1)) auto diff =
(intDig + decDig) - 38;
if (diff > 0) {
decDig -= diff / 2 + 1 intDig = 38 - decDig
}
return bounded(intDig + decDig, decDig);
}
}
};

Expand Down Expand Up @@ -664,13 +688,14 @@ template <typename Operation>
std::shared_ptr<exec::VectorFunction> createDecimalFunction(
const std::string& name,
const std::vector<exec::VectorFunctionArg>& inputArgs,
const core::QueryConfig& /*config*/) {
const core::QueryConfig& config) {
auto aType = inputArgs[0].type;
auto bType = inputArgs[1].type;
auto [aPrecision, aScale] = getDecimalPrecisionScale(*aType);
auto [bPrecision, bScale] = getDecimalPrecisionScale(*bType);
const bool allowPrecisionLoss = config.isAllowPrecisionLoss();
auto [rPrecision, rScale] = Operation::computeResultPrecisionScale(
aPrecision, aScale, bPrecision, bScale);
aPrecision, aScale, bPrecision, bScale, allowPrecisionLoss);
uint8_t aRescale = Operation::computeRescaleFactor(aScale, bScale, rScale);
uint8_t bRescale = Operation::computeRescaleFactor(bScale, aScale, rScale);
if (aType->isShortDecimal()) {
Expand Down

0 comments on commit 74a7a98

Please sign in to comment.