Skip to content

Commit

Permalink
fix checkpverflow should have cast node
Browse files Browse the repository at this point in the history
  • Loading branch information
jinchengchenghh committed May 9, 2024
1 parent 0d68d43 commit 0870c48
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 70 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -449,14 +449,12 @@ object VeloxBackendSettings extends BackendSettingsApi {
override def fallbackAggregateWithEmptyOutputChild(): Boolean = true

override def recreateJoinExecOnFallback(): Boolean = true
override def rescaleDecimalLiteral(): Boolean = true
override def rescaleDecimalArithmetic(): Boolean = true

/** Get the config prefix for each backend */
override def getBackendConfigPrefix(): String =
GlutenConfig.GLUTEN_CONFIG_PREFIX + VeloxBackend.BACKEND_NAME

override def rescaleDecimalIntegralExpression(): Boolean = true

override def shuffleSupportedCodec(): Set[String] = SHUFFLE_SUPPORTED_CODEC

override def resolveNativeConf(nativeConf: java.util.Map[String, String]): Unit = {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ trait BackendSettingsApi {
def supportShuffleWithProject(outputPartitioning: Partitioning, child: SparkPlan): Boolean = false
def utilizeShuffledHashJoinHint(): Boolean = false
def excludeScanExecFromCollapsedStage(): Boolean = false
def rescaleDecimalLiteral: Boolean = false
def rescaleDecimalArithmetic: Boolean = false

/**
* Whether to replace sort agg with hash agg., e.g., sort agg will be used in spark's planning for
Expand All @@ -98,8 +98,6 @@ trait BackendSettingsApi {

def allowDecimalArithmetic: Boolean = true

def rescaleDecimalIntegralExpression(): Boolean = false

def shuffleSupportedCodec(): Set[String]

def needOutputSchemaForPlan(): Boolean = false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -492,12 +492,47 @@ object ExpressionConverter extends SQLConfHelper with Logging {
expr.children.map(
replaceWithExpressionTransformerInternal(_, attributeSeq, expressionsMap)),
expr)
case CheckOverflow(b: BinaryArithmetic, decimalType, _) =>
genDecimalExpressionTransformer(
b,
getAndCheckSubstraitName(b, expressionsMap),
attributeSeq,
expressionsMap)
case c @ CheckOverflow(b: BinaryArithmetic, decimalType, _) =>
DecimalArithmeticUtil.checkAllowDecimalArithmetic()
if (BackendsApiManager.getSettings.rescaleDecimalArithmetic) {
val rescaleBinary = DecimalArithmeticUtil.rescaleLiteral(b)
val (left, right) = DecimalArithmeticUtil.rescaleCastForDecimal(
DecimalArithmeticUtil.removeCastForDecimal(rescaleBinary.left),
DecimalArithmeticUtil.removeCastForDecimal(rescaleBinary.right))
val resultType = DecimalArithmeticUtil.getResultType(
b,
left.dataType.asInstanceOf[DecimalType],
right.dataType.asInstanceOf[DecimalType]
)

val leftChild =
replaceWithExpressionTransformerInternal(left, attributeSeq, expressionsMap)
val rightChild =
replaceWithExpressionTransformerInternal(right, attributeSeq, expressionsMap)
val child = DecimalArithmeticExpressionTransformer(
getAndCheckSubstraitName(b, expressionsMap),
leftChild,
rightChild,
resultType,
b)
if (!resultType.equals(decimalType)) {
// in velox, add cast node
CheckOverflowTransformer(substraitExprName, child, c)
} else {
child
}
} else {
val leftChild =
replaceWithExpressionTransformerInternal(b.left, attributeSeq, expressionsMap)
val rightChild =
replaceWithExpressionTransformerInternal(b.right, attributeSeq, expressionsMap)
DecimalArithmeticExpressionTransformer(
getAndCheckSubstraitName(b, expressionsMap),
leftChild,
rightChild,
decimalType,
b)
}
case c: CheckOverflow =>
CheckOverflowTransformer(
substraitExprName,
Expand Down Expand Up @@ -625,33 +660,6 @@ object ExpressionConverter extends SQLConfHelper with Logging {
substraitExprName
}

private def genDecimalExpressionTransformer(
b: BinaryArithmetic,
substraitExprName: String,
attributeSeq: Seq[Attribute],
expressionsMap: Map[Class[_], String]) = {
DecimalArithmeticUtil.checkAllowDecimalArithmetic()
val rescaleBinary = if (BackendsApiManager.getSettings.rescaleDecimalLiteral) {
DecimalArithmeticUtil.rescaleLiteral(b)
} else {
b
}
val (left, right) = DecimalArithmeticUtil.rescaleCastForDecimal(
DecimalArithmeticUtil.removeCastForDecimal(rescaleBinary.left),
DecimalArithmeticUtil.removeCastForDecimal(rescaleBinary.right))
val resultType = DecimalArithmeticUtil.getResultType(
b,
left.dataType.asInstanceOf[DecimalType],
right.dataType.asInstanceOf[DecimalType]
)
val leftChild =
replaceWithExpressionTransformerInternal(left, attributeSeq, expressionsMap)
val rightChild =
replaceWithExpressionTransformerInternal(right, attributeSeq, expressionsMap)

DecimalArithmeticExpressionTransformer(substraitExprName, leftChild, rightChild, resultType, b)
}

/**
* Transform BroadcastExchangeExec to ColumnarBroadcastExchangeExec in DynamicPruningExpression.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ package org.apache.gluten.utils

import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.exception.GlutenNotSupportException
import org.apache.gluten.expression.{CheckOverflowTransformer, ChildTransformer, DecimalArithmeticExpressionTransformer, ExpressionTransformer}
import org.apache.gluten.expression.ExpressionConverter.conf

import org.apache.spark.sql.catalyst.analysis.DecimalPrecision
Expand All @@ -27,8 +26,6 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{ByteType, Decimal, DecimalType, IntegerType, LongType, ShortType}
import org.apache.spark.sql.utils.DecimalTypeUtil

import scala.annotation.tailrec

object DecimalArithmeticUtil {

// Returns the result decimal type of a decimal arithmetic computing.
Expand Down Expand Up @@ -145,9 +142,7 @@ object DecimalArithmeticUtil {
if (isWiderType) (e1, newE2) else (e1, e2)
}

if (!BackendsApiManager.getSettings.rescaleDecimalIntegralExpression()) {
(left, right)
} else if (!isPromoteCast(left) && isPromoteCastIntegral(right)) {
if (!isPromoteCast(left) && isPromoteCastIntegral(right)) {
// Have removed PromotePrecision(Cast(DecimalType)).
// Decimal * cast int.
doScale(left, right)
Expand All @@ -169,44 +164,28 @@ object DecimalArithmeticUtil {
* expression removed child PromotePrecision->Cast
*/
def removeCastForDecimal(arithmeticExpr: Expression): Expression = arithmeticExpr match {
case PromotePrecision(cast: Cast)
if cast.dataType.isInstanceOf[DecimalType] &&
cast.child.dataType.isInstanceOf[DecimalType] =>
cast.child
case PromotePrecision(_ @Cast(child, _: DecimalType, _, _))
if child.dataType.isInstanceOf[DecimalType] =>
child
case _ => arithmeticExpr
}

@tailrec
def getResultType(transformer: ExpressionTransformer): Option[DecimalType] = {
transformer match {
case ChildTransformer(child) =>
getResultType(child)
case CheckOverflowTransformer(_, _, original) =>
Some(original.dataType)
case DecimalArithmeticExpressionTransformer(_, _, _, resultType, _) =>
Some(resultType)
case _ => None
}
}

private def isPromoteCastIntegral(expr: Expression): Boolean = expr match {
case PromotePrecision(cast: Cast) if cast.dataType.isInstanceOf[DecimalType] =>
cast.child.dataType match {
case PromotePrecision(_ @Cast(child, _: DecimalType, _, _)) =>
child.dataType match {
case IntegerType | ByteType | ShortType | LongType => true
case _ => false
}
case _ => false
}

private def rescaleCastForOneSide(expr: Expression): Expression = expr match {
case precision @ PromotePrecision(castInt: Cast)
if castInt.dataType.isInstanceOf[DecimalType]
&& BackendsApiManager.getSettings.rescaleDecimalIntegralExpression() =>
castInt.child.dataType match {
case precision @ PromotePrecision(_ @Cast(child, _: DecimalType, _, _)) =>
child.dataType match {
case IntegerType | ByteType | ShortType =>
precision.withNewChildren(Seq(Cast(castInt.child, DecimalType(10, 0))))
precision.withNewChildren(Seq(Cast(child, DecimalType(10, 0))))
case LongType =>
precision.withNewChildren(Seq(Cast(castInt.child, DecimalType(20, 0))))
precision.withNewChildren(Seq(Cast(child, DecimalType(20, 0))))
case _ => expr
}
case _ => expr
Expand Down

0 comments on commit 0870c48

Please sign in to comment.