diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala index aad8ff5d5d55..61a46e15f51f 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala @@ -449,14 +449,12 @@ object VeloxBackendSettings extends BackendSettingsApi { override def fallbackAggregateWithEmptyOutputChild(): Boolean = true override def recreateJoinExecOnFallback(): Boolean = true - override def rescaleDecimalLiteral(): Boolean = true + override def rescaleDecimalArithmetic(): Boolean = true /** Get the config prefix for each backend */ override def getBackendConfigPrefix(): String = GlutenConfig.GLUTEN_CONFIG_PREFIX + VeloxBackend.BACKEND_NAME - override def rescaleDecimalIntegralExpression(): Boolean = true - override def shuffleSupportedCodec(): Set[String] = SHUFFLE_SUPPORTED_CODEC override def resolveNativeConf(nativeConf: java.util.Map[String, String]): Unit = {} diff --git a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala index e34cb1562092..8ed8e71306ba 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala @@ -85,7 +85,7 @@ trait BackendSettingsApi { def supportShuffleWithProject(outputPartitioning: Partitioning, child: SparkPlan): Boolean = false def utilizeShuffledHashJoinHint(): Boolean = false def excludeScanExecFromCollapsedStage(): Boolean = false - def rescaleDecimalLiteral: Boolean = false + def rescaleDecimalArithmetic: Boolean = false /** * Whether to replace sort agg with hash agg., e.g., sort agg will be used in spark's planning for @@ -98,8 +98,6 @@ trait BackendSettingsApi { def allowDecimalArithmetic: Boolean = true - def rescaleDecimalIntegralExpression(): Boolean = false - def shuffleSupportedCodec(): Set[String] def needOutputSchemaForPlan(): Boolean = false diff --git a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala index e0d37f971b71..7f83840dce1c 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala @@ -492,12 +492,47 @@ object ExpressionConverter extends SQLConfHelper with Logging { expr.children.map( replaceWithExpressionTransformerInternal(_, attributeSeq, expressionsMap)), expr) - case CheckOverflow(b: BinaryArithmetic, decimalType, _) => - genDecimalExpressionTransformer( - b, - getAndCheckSubstraitName(b, expressionsMap), - attributeSeq, - expressionsMap) + case c @ CheckOverflow(b: BinaryArithmetic, decimalType, _) => + DecimalArithmeticUtil.checkAllowDecimalArithmetic() + if (BackendsApiManager.getSettings.rescaleDecimalArithmetic) { + val rescaleBinary = DecimalArithmeticUtil.rescaleLiteral(b) + val (left, right) = DecimalArithmeticUtil.rescaleCastForDecimal( + DecimalArithmeticUtil.removeCastForDecimal(rescaleBinary.left), + DecimalArithmeticUtil.removeCastForDecimal(rescaleBinary.right)) + val resultType = DecimalArithmeticUtil.getResultType( + b, + left.dataType.asInstanceOf[DecimalType], + right.dataType.asInstanceOf[DecimalType] + ) + + val leftChild = + replaceWithExpressionTransformerInternal(left, attributeSeq, expressionsMap) + val rightChild = + replaceWithExpressionTransformerInternal(right, attributeSeq, expressionsMap) + val child = DecimalArithmeticExpressionTransformer( + getAndCheckSubstraitName(b, expressionsMap), + leftChild, + rightChild, + resultType, + b) + if (!resultType.equals(decimalType)) { + // in velox, add cast node + CheckOverflowTransformer(substraitExprName, child, c) + } else { + child + } + } else { + val leftChild = + replaceWithExpressionTransformerInternal(b.left, attributeSeq, expressionsMap) + val rightChild = + replaceWithExpressionTransformerInternal(b.right, attributeSeq, expressionsMap) + DecimalArithmeticExpressionTransformer( + getAndCheckSubstraitName(b, expressionsMap), + leftChild, + rightChild, + decimalType, + b) + } case c: CheckOverflow => CheckOverflowTransformer( substraitExprName, @@ -625,33 +660,6 @@ object ExpressionConverter extends SQLConfHelper with Logging { substraitExprName } - private def genDecimalExpressionTransformer( - b: BinaryArithmetic, - substraitExprName: String, - attributeSeq: Seq[Attribute], - expressionsMap: Map[Class[_], String]) = { - DecimalArithmeticUtil.checkAllowDecimalArithmetic() - val rescaleBinary = if (BackendsApiManager.getSettings.rescaleDecimalLiteral) { - DecimalArithmeticUtil.rescaleLiteral(b) - } else { - b - } - val (left, right) = DecimalArithmeticUtil.rescaleCastForDecimal( - DecimalArithmeticUtil.removeCastForDecimal(rescaleBinary.left), - DecimalArithmeticUtil.removeCastForDecimal(rescaleBinary.right)) - val resultType = DecimalArithmeticUtil.getResultType( - b, - left.dataType.asInstanceOf[DecimalType], - right.dataType.asInstanceOf[DecimalType] - ) - val leftChild = - replaceWithExpressionTransformerInternal(left, attributeSeq, expressionsMap) - val rightChild = - replaceWithExpressionTransformerInternal(right, attributeSeq, expressionsMap) - - DecimalArithmeticExpressionTransformer(substraitExprName, leftChild, rightChild, resultType, b) - } - /** * Transform BroadcastExchangeExec to ColumnarBroadcastExchangeExec in DynamicPruningExpression. * diff --git a/gluten-core/src/main/scala/org/apache/gluten/utils/DecimalArithmeticUtil.scala b/gluten-core/src/main/scala/org/apache/gluten/utils/DecimalArithmeticUtil.scala index 70860dd09cd3..cc22734bad04 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/utils/DecimalArithmeticUtil.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/utils/DecimalArithmeticUtil.scala @@ -18,7 +18,6 @@ package org.apache.gluten.utils import org.apache.gluten.backendsapi.BackendsApiManager import org.apache.gluten.exception.GlutenNotSupportException -import org.apache.gluten.expression.{CheckOverflowTransformer, ChildTransformer, DecimalArithmeticExpressionTransformer, ExpressionTransformer} import org.apache.gluten.expression.ExpressionConverter.conf import org.apache.spark.sql.catalyst.analysis.DecimalPrecision @@ -27,8 +26,6 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{ByteType, Decimal, DecimalType, IntegerType, LongType, ShortType} import org.apache.spark.sql.utils.DecimalTypeUtil -import scala.annotation.tailrec - object DecimalArithmeticUtil { // Returns the result decimal type of a decimal arithmetic computing. @@ -145,9 +142,7 @@ object DecimalArithmeticUtil { if (isWiderType) (e1, newE2) else (e1, e2) } - if (!BackendsApiManager.getSettings.rescaleDecimalIntegralExpression()) { - (left, right) - } else if (!isPromoteCast(left) && isPromoteCastIntegral(right)) { + if (!isPromoteCast(left) && isPromoteCastIntegral(right)) { // Have removed PromotePrecision(Cast(DecimalType)). // Decimal * cast int. doScale(left, right) @@ -169,29 +164,15 @@ object DecimalArithmeticUtil { * expression removed child PromotePrecision->Cast */ def removeCastForDecimal(arithmeticExpr: Expression): Expression = arithmeticExpr match { - case PromotePrecision(cast: Cast) - if cast.dataType.isInstanceOf[DecimalType] && - cast.child.dataType.isInstanceOf[DecimalType] => - cast.child + case PromotePrecision(_ @Cast(child, _: DecimalType, _, _)) + if child.dataType.isInstanceOf[DecimalType] => + child case _ => arithmeticExpr } - @tailrec - def getResultType(transformer: ExpressionTransformer): Option[DecimalType] = { - transformer match { - case ChildTransformer(child) => - getResultType(child) - case CheckOverflowTransformer(_, _, original) => - Some(original.dataType) - case DecimalArithmeticExpressionTransformer(_, _, _, resultType, _) => - Some(resultType) - case _ => None - } - } - private def isPromoteCastIntegral(expr: Expression): Boolean = expr match { - case PromotePrecision(cast: Cast) if cast.dataType.isInstanceOf[DecimalType] => - cast.child.dataType match { + case PromotePrecision(_ @Cast(child, _: DecimalType, _, _)) => + child.dataType match { case IntegerType | ByteType | ShortType | LongType => true case _ => false } @@ -199,14 +180,12 @@ object DecimalArithmeticUtil { } private def rescaleCastForOneSide(expr: Expression): Expression = expr match { - case precision @ PromotePrecision(castInt: Cast) - if castInt.dataType.isInstanceOf[DecimalType] - && BackendsApiManager.getSettings.rescaleDecimalIntegralExpression() => - castInt.child.dataType match { + case precision @ PromotePrecision(_ @Cast(child, _: DecimalType, _, _)) => + child.dataType match { case IntegerType | ByteType | ShortType => - precision.withNewChildren(Seq(Cast(castInt.child, DecimalType(10, 0)))) + precision.withNewChildren(Seq(Cast(child, DecimalType(10, 0)))) case LongType => - precision.withNewChildren(Seq(Cast(castInt.child, DecimalType(20, 0)))) + precision.withNewChildren(Seq(Cast(child, DecimalType(20, 0)))) case _ => expr } case _ => expr