From 67fa0192b29b90c14672f9abff083686da190b5a Mon Sep 17 00:00:00 2001 From: Yuan Zhou Date: Fri, 25 Aug 2023 09:44:49 +0800 Subject: [PATCH] support allowDecimalPrecisionLoss Signed-off-by: Yuan Zhou --- cpp/core/config/GlutenConfig.h | 2 ++ cpp/velox/compute/WholeStageResultIterator.cc | 2 ++ ep/build-velox/src/get_velox.sh | 4 +-- .../expression/ExpressionConverter.scala | 8 ----- .../gluten/utils/DecimalArithmeticUtil.scala | 29 +++++++++++++++++-- .../org/apache/gluten/GlutenConfig.scala | 1 + 6 files changed, 33 insertions(+), 13 deletions(-) diff --git a/cpp/core/config/GlutenConfig.h b/cpp/core/config/GlutenConfig.h index 3c47fb5479bd..cf34b6a72c80 100644 --- a/cpp/core/config/GlutenConfig.h +++ b/cpp/core/config/GlutenConfig.h @@ -34,6 +34,8 @@ const std::string kLegacySize = "spark.sql.legacy.sizeOfNull"; const std::string kSessionTimezone = "spark.sql.session.timeZone"; +const std::string kAllowPrecisionLoss = "spark.sql.decimalOperations.allowPrecisionLoss"; + const std::string kIgnoreMissingFiles = "spark.sql.files.ignoreMissingFiles"; const std::string kDefaultSessionTimezone = "spark.gluten.sql.session.timeZone.default"; diff --git a/cpp/velox/compute/WholeStageResultIterator.cc b/cpp/velox/compute/WholeStageResultIterator.cc index 83749061c1b8..a142023ad7a8 100644 --- a/cpp/velox/compute/WholeStageResultIterator.cc +++ b/cpp/velox/compute/WholeStageResultIterator.cc @@ -490,6 +490,8 @@ std::unordered_map WholeStageResultIterator::getQueryC } // Adjust timestamp according to the above configured session timezone. configs[velox::core::QueryConfig::kAdjustTimestampToTimezone] = "true"; + // To align with Spark's behavior, allow decimal precision loss or not. + configs[velox::core::QueryConfig::kAllowPrecisionLoss] = veloxCfg_->get(kAllowPrecisionLoss, "true"); // Align Velox size function with Spark. configs[velox::core::QueryConfig::kSparkLegacySizeOfNull] = std::to_string(veloxCfg_->get(kLegacySize, true)); diff --git a/ep/build-velox/src/get_velox.sh b/ep/build-velox/src/get_velox.sh index c26bedd5e9af..0ba3ac2ea2fc 100755 --- a/ep/build-velox/src/get_velox.sh +++ b/ep/build-velox/src/get_velox.sh @@ -16,8 +16,8 @@ set -exu -VELOX_REPO=https://github.com/oap-project/velox.git -VELOX_BRANCH=2024_05_06 +VELOX_REPO=https://github.com/zhouyuan/velox.git +VELOX_BRANCH=wip_decimal_precision_loss VELOX_HOME="" #Set on run gluten on HDFS diff --git a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala index 7815cbf69ebd..f22f12c3ef87 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala @@ -511,14 +511,6 @@ object ExpressionConverter extends SQLConfHelper with Logging { replaceWithExpressionTransformerInternal(_, attributeSeq, expressionsMap)), expr) case b: BinaryArithmetic if DecimalArithmeticUtil.isDecimalArithmetic(b) => - // PrecisionLoss=true: velox support / ch not support - // PrecisionLoss=false: velox not support / ch support - // TODO ch support PrecisionLoss=true - if (!BackendsApiManager.getSettings.allowDecimalArithmetic) { - throw new GlutenNotSupportException( - s"Not support ${SQLConf.DECIMAL_OPERATIONS_ALLOW_PREC_LOSS.key} " + - s"${conf.decimalOperationsAllowPrecisionLoss} mode") - } val rescaleBinary = if (BackendsApiManager.getSettings.rescaleDecimalLiteral) { DecimalArithmeticUtil.rescaleLiteral(b) } else { diff --git a/gluten-core/src/main/scala/org/apache/gluten/utils/DecimalArithmeticUtil.scala b/gluten-core/src/main/scala/org/apache/gluten/utils/DecimalArithmeticUtil.scala index 621dcc061ec7..cf914a5cc010 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/utils/DecimalArithmeticUtil.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/utils/DecimalArithmeticUtil.scala @@ -22,6 +22,7 @@ import org.apache.gluten.expression.{CheckOverflowTransformer, ChildTransformer, import org.apache.spark.sql.catalyst.analysis.DecimalPrecision import org.apache.spark.sql.catalyst.expressions.{Add, BinaryArithmetic, Cast, Divide, Expression, Literal, Multiply, Pmod, PromotePrecision, Remainder, Subtract} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{ByteType, Decimal, DecimalType, IntegerType, LongType, ShortType} object DecimalArithmeticUtil { @@ -33,12 +34,14 @@ object DecimalArithmeticUtil { val MIN_ADJUSTED_SCALE = 6 val MAX_PRECISION = 38 + val MAX_SCALE = 38 // Returns the result decimal type of a decimal arithmetic computing. def getResultTypeForOperation( operationType: OperationType.Config, type1: DecimalType, type2: DecimalType): DecimalType = { + val allowPrecisionLoss = SQLConf.get.decimalOperationsAllowPrecisionLoss var resultScale = 0 var resultPrecision = 0 operationType match { @@ -54,8 +57,20 @@ object DecimalArithmeticUtil { resultScale = type1.scale + type2.scale resultPrecision = type1.precision + type2.precision + 1 case OperationType.DIVIDE => - resultScale = Math.max(MIN_ADJUSTED_SCALE, type1.scale + type2.precision + 1) - resultPrecision = type1.precision - type1.scale + type2.scale + resultScale + if (allowPrecisionLoss) { + resultScale = Math.max(MIN_ADJUSTED_SCALE, type1.scale + type2.precision + 1) + resultPrecision = type1.precision - type1.scale + type2.scale + resultScale + } else { + var intDig = Math.min(MAX_SCALE, type1.precision - type1.scale + type2.scale) + var decDig = Math.min(MAX_SCALE, Math.max(6, type1.scale + type2.precision + 1)) + val diff = (intDig + decDig) - MAX_SCALE + if (diff > 0) { + decDig -= diff / 2 + 1 + intDig = MAX_SCALE - decDig + } + resultScale = intDig + decDig + resultPrecision = decDig + } case OperationType.MOD => resultScale = Math.max(type1.scale, type2.scale) resultPrecision = @@ -63,7 +78,11 @@ object DecimalArithmeticUtil { case other => throw new GlutenNotSupportException(s"$other is not supported.") } - adjustScaleIfNeeded(resultPrecision, resultScale) + if (allowPrecisionLoss) { + adjustScaleIfNeeded(resultPrecision, resultScale) + } else { + bounded(resultPrecision, resultScale) + } } // Returns the adjusted decimal type when the precision is larger the maximum. @@ -79,6 +98,10 @@ object DecimalArithmeticUtil { DecimalType(typePrecision, typeScale) } + def bounded(precision: Int, scale: Int): DecimalType = { + DecimalType(Math.min(precision, MAX_PRECISION), Math.min(scale, MAX_SCALE)) + } + // If casting between DecimalType, unnecessary cast is skipped to avoid data loss, // because argument input type of "cast" is actually the res type of "+-*/". // Cast will use a wider input type, then calculates result type with less scale than expected. diff --git a/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala b/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala index 3c7ddf32c71f..414eb189067b 100644 --- a/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala +++ b/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala @@ -571,6 +571,7 @@ object GlutenConfig { GLUTEN_DEFAULT_SESSION_TIMEZONE_KEY, SQLConf.LEGACY_SIZE_OF_NULL.key, "spark.io.compression.codec", + "spark.sql.decimalOperations.allowPrecisionLoss", COLUMNAR_VELOX_BLOOM_FILTER_EXPECTED_NUM_ITEMS.key, COLUMNAR_VELOX_BLOOM_FILTER_NUM_BITS.key, COLUMNAR_VELOX_BLOOM_FILTER_MAX_NUM_BITS.key,