diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala index 554b3791dad3..4755adc91245 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala @@ -159,7 +159,15 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi { } else if (SparkShimLoader.getSparkShims.withAnsiEvalMode(original)) { throw new GlutenNotSupportException(s"$substraitExprName with ansi mode is not supported") } else { - GenericExpressionTransformer(substraitExprName, Seq(left, right), original) + if ( + left.dataType.isInstanceOf[DecimalType] && right.dataType + .isInstanceOf[DecimalType] && !SQLConf.get.decimalOperationsAllowPrecisionLoss + ) { + val newName = "not_allow_precision_loss_" + GenericExpressionTransformer(newName, Seq(left, right), original) + } else { + GenericExpressionTransformer(substraitExprName, Seq(left, right), original) + } } } diff --git a/cpp/velox/compute/WholeStageResultIterator.cc b/cpp/velox/compute/WholeStageResultIterator.cc index c306564dc8d9..cd72c714cb1f 100644 --- a/cpp/velox/compute/WholeStageResultIterator.cc +++ b/cpp/velox/compute/WholeStageResultIterator.cc @@ -96,7 +96,7 @@ WholeStageResultIterator::WholeStageResultIterator( 0, std::move(queryCtx), velox::exec::Task::ExecutionMode::kSerial); - if (!task_->supportsSingleThreadedExecution()) { + if (!task_->supportSerialExecutionMode()) { throw std::runtime_error("Task doesn't support single thread execution: " + planNode->toString()); } auto fileSystem = velox::filesystems::getFileSystem(spillDir, nullptr); @@ -445,10 +445,6 @@ std::unordered_map WholeStageResultIterator::getQueryC // Adjust timestamp according to the above configured session timezone. configs[velox::core::QueryConfig::kAdjustTimestampToTimezone] = "true"; - // To align with Spark's behavior, allow decimal precision loss or not. - configs[velox::core::QueryConfig::kSparkDecimalOperationsAllowPrecisionLoss] = - veloxCfg_->get(kAllowPrecisionLoss, "true"); - { // partial aggregation memory config auto offHeapMemory = veloxCfg_->get(kSparkTaskOffHeapMemory, facebook::velox::memory::kMaxMemory); diff --git a/cpp/velox/memory/VeloxMemoryManager.cc b/cpp/velox/memory/VeloxMemoryManager.cc index 6b5606dd228e..dc6ad6317c0d 100644 --- a/cpp/velox/memory/VeloxMemoryManager.cc +++ b/cpp/velox/memory/VeloxMemoryManager.cc @@ -316,22 +316,27 @@ bool VeloxMemoryManager::tryDestructSafe() { // Velox memory manager considered safe to destruct when no alive pools. if (veloxMemoryManager_) { - if (veloxMemoryManager_->numPools() > 1) { + if (veloxMemoryManager_->numPools() > 2) { return false; } - if (veloxMemoryManager_->numPools() == 1) { + if (veloxMemoryManager_->numPools() == 2) { // Assert the pool is spill pool // See https://github.com/facebookincubator/velox/commit/e6f84e8ac9ef6721f527a2d552a13f7e79bdf72e int32_t spillPoolCount = 0; + int32_t tracePoolCount = 0; veloxMemoryManager_->testingDefaultRoot().visitChildren([&](velox::memory::MemoryPool* child) -> bool { if (child == veloxMemoryManager_->spillPool()) { spillPoolCount++; } + if (child == veloxMemoryManager_->tracePool()) { + tracePoolCount++; + } return true; }); GLUTEN_CHECK(spillPoolCount == 1, "Illegal pool count state: spillPoolCount: " + std::to_string(spillPoolCount)); + GLUTEN_CHECK(tracePoolCount == 1, "Illegal pool count state: tracePoolCount: " + std::to_string(tracePoolCount)); } - if (veloxMemoryManager_->numPools() < 1) { + if (veloxMemoryManager_->numPools() < 2) { GLUTEN_CHECK(false, "Unreachable code"); } } diff --git a/cpp/velox/operators/functions/RegistrationAllFunctions.cc b/cpp/velox/operators/functions/RegistrationAllFunctions.cc index 6b6564fa4aa3..6e2f90f0105b 100644 --- a/cpp/velox/operators/functions/RegistrationAllFunctions.cc +++ b/cpp/velox/operators/functions/RegistrationAllFunctions.cc @@ -26,6 +26,7 @@ #include "velox/functions/prestosql/aggregates/RegisterAggregateFunctions.h" #include "velox/functions/prestosql/registration/RegistrationFunctions.h" #include "velox/functions/prestosql/window/WindowFunctionsRegistration.h" +#include "velox/functions/sparksql/DecimalArithmetic.h" #include "velox/functions/sparksql/Hash.h" #include "velox/functions/sparksql/Rand.h" #include "velox/functions/sparksql/Register.h" @@ -74,6 +75,14 @@ void registerFunctionOverwrite() { velox::functions::registerPrestoVectorFunctions(); } + +void registerFunctionForConfig() { + const std::string prefix = "not_allow_precision_loss_"; + velox::functions::sparksql::registerDecimalAdd(prefix, false); + velox::functions::sparksql::registerDecimalSubtract(prefix, false); + velox::functions::sparksql::registerDecimalMultiply(prefix, false); + velox::functions::sparksql::registerDecimalDivide(prefix, false); +} } // namespace void registerAllFunctions() { @@ -87,6 +96,7 @@ void registerAllFunctions() { // Using function overwrite to handle function names mismatch between Spark // and Velox. registerFunctionOverwrite(); + registerFunctionForConfig(); } } // namespace gluten diff --git a/ep/build-velox/src/get_velox.sh b/ep/build-velox/src/get_velox.sh index 003a3790cedf..f6394f850231 100755 --- a/ep/build-velox/src/get_velox.sh +++ b/ep/build-velox/src/get_velox.sh @@ -17,7 +17,7 @@ set -exu VELOX_REPO=https://github.com/oap-project/velox.git -VELOX_BRANCH=2024_08_27 +VELOX_BRANCH=2024_09_01 VELOX_HOME="" OS=`uname -s`