Skip to content

Commit

Permalink
Support decimal allow precision loss
Browse files Browse the repository at this point in the history
  • Loading branch information
jinchengchenghh committed Aug 30, 2024
1 parent bf1d097 commit e404c13
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,15 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
} else if (SparkShimLoader.getSparkShims.withAnsiEvalMode(original)) {
throw new GlutenNotSupportException(s"$substraitExprName with ansi mode is not supported")
} else {
GenericExpressionTransformer(substraitExprName, Seq(left, right), original)
if (
left.dataType.isInstanceOf[DecimalType] && right.dataType
.isInstanceOf[DecimalType] && !SQLConf.get.decimalOperationsAllowPrecisionLoss
) {
val newName = "not_allow_precision_loss_"
GenericExpressionTransformer(newName, Seq(left, right), original)
} else {
GenericExpressionTransformer(substraitExprName, Seq(left, right), original)
}
}
}

Expand Down
4 changes: 0 additions & 4 deletions cpp/velox/compute/WholeStageResultIterator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -439,10 +439,6 @@ std::unordered_map<std::string, std::string> WholeStageResultIterator::getQueryC
// Adjust timestamp according to the above configured session timezone.
configs[velox::core::QueryConfig::kAdjustTimestampToTimezone] = "true";

// To align with Spark's behavior, allow decimal precision loss or not.
configs[velox::core::QueryConfig::kSparkDecimalOperationsAllowPrecisionLoss] =
veloxCfg_->get<std::string>(kAllowPrecisionLoss, "true");

{
// partial aggregation memory config
auto offHeapMemory = veloxCfg_->get<int64_t>(kSparkTaskOffHeapMemory, facebook::velox::memory::kMaxMemory);
Expand Down
10 changes: 10 additions & 0 deletions cpp/velox/operators/functions/RegistrationAllFunctions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "velox/functions/prestosql/aggregates/RegisterAggregateFunctions.h"
#include "velox/functions/prestosql/registration/RegistrationFunctions.h"
#include "velox/functions/prestosql/window/WindowFunctionsRegistration.h"
#include "velox/functions/sparksql/DecimalArithmetic.h"
#include "velox/functions/sparksql/Hash.h"
#include "velox/functions/sparksql/Rand.h"
#include "velox/functions/sparksql/Register.h"
Expand Down Expand Up @@ -74,6 +75,14 @@ void registerFunctionOverwrite() {

velox::functions::registerPrestoVectorFunctions();
}

void registerFunctionForConfig() {
const std::string prefix = "not_allow_precision_loss_";
velox::functions::sparksql::registerDecimalAdd(prefix, false);
velox::functions::sparksql::registerDecimalSubtract(prefix, false);
velox::functions::sparksql::registerDecimalMultiply(prefix, false);
velox::functions::sparksql::registerDecimalDivide(prefix, false);
}
} // namespace

void registerAllFunctions() {
Expand All @@ -87,6 +96,7 @@ void registerAllFunctions() {
// Using function overwrite to handle function names mismatch between Spark
// and Velox.
registerFunctionOverwrite();
registerFunctionForConfig();
}

} // namespace gluten
2 changes: 1 addition & 1 deletion ep/build-velox/src/get_velox.sh
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
set -exu

VELOX_REPO=https://github.com/oap-project/velox.git
VELOX_BRANCH=2024_08_28
VELOX_BRANCH=2024_08_28_fix
VELOX_HOME=""

OS=`uname -s`
Expand Down

0 comments on commit e404c13

Please sign in to comment.