[GLUTEN-5620][CORE] Remove check_overflow and refactor code (#5654)
jinchengchenghh authored May 14, 2024
1 parent 4907f25 commit e807856
Showing 9 changed files with 104 additions and 159 deletions.
@@ -219,6 +219,7 @@ class CHTransformerApi extends TransformerApi with Logging {
args: java.lang.Object,
substraitExprName: String,
childNode: ExpressionNode,
childResultType: DataType,
dataType: DecimalType,
nullable: Boolean,
nullOnOverflow: Boolean): ExpressionNode = {
@@ -449,14 +449,12 @@ object VeloxBackendSettings extends BackendSettingsApi {
override def fallbackAggregateWithEmptyOutputChild(): Boolean = true

override def recreateJoinExecOnFallback(): Boolean = true
override def rescaleDecimalLiteral(): Boolean = true
override def rescaleDecimalArithmetic(): Boolean = true

/** Get the config prefix for each backend */
override def getBackendConfigPrefix(): String =
GlutenConfig.GLUTEN_CONFIG_PREFIX + VeloxBackend.BACKEND_NAME

override def rescaleDecimalIntegralExpression(): Boolean = true

override def shuffleSupportedCodec(): Set[String] = SHUFFLE_SUPPORTED_CODEC

override def resolveNativeConf(nativeConf: java.util.Map[String, String]): Unit = {}
@@ -80,11 +80,16 @@ class VeloxTransformerApi extends TransformerApi with Logging {
args: java.lang.Object,
substraitExprName: String,
childNode: ExpressionNode,
childResultType: DataType,
dataType: DecimalType,
nullable: Boolean,
nullOnOverflow: Boolean): ExpressionNode = {
val typeNode = ConverterUtils.getTypeNode(dataType, nullable)
ExpressionBuilder.makeCast(typeNode, childNode, !nullOnOverflow)
if (childResultType.equals(dataType)) {
childNode
} else {
val typeNode = ConverterUtils.getTypeNode(dataType, nullable)
ExpressionBuilder.makeCast(typeNode, childNode, !nullOnOverflow)
}
}

override def getNativePlanString(substraitPlan: Array[Byte], details: Boolean): String = {
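For reference, a minimal standalone sketch of the decision the updated createCheckOverflowExprNode now makes on the Velox side: when the child expression already produces the target DecimalType, no cast node is emitted and the child is reused; otherwise a cast is built whose overflow behaviour follows nullOnOverflow. The Node/ColumnRef/CastNode types below are stand-ins invented for illustration, not Gluten's ExpressionNode/ExpressionBuilder API.

```scala
// Illustrative sketch only -- simplified stand-ins, not Gluten or Substrait classes.
object CheckOverflowSketch {
  final case class DecimalType(precision: Int, scale: Int)

  sealed trait Node
  final case class ColumnRef(name: String, dataType: DecimalType) extends Node
  final case class CastNode(child: Node, to: DecimalType, failOnOverflow: Boolean) extends Node

  // Mirrors the new logic: skip the cast when it would be a no-op.
  def createCheckOverflowNode(
      child: Node,
      childResultType: DecimalType,
      target: DecimalType,
      nullOnOverflow: Boolean): Node =
    if (childResultType == target) child // child already has the CheckOverflow result type
    else CastNode(child, target, failOnOverflow = !nullOnOverflow)

  def main(args: Array[String]): Unit = {
    val price = ColumnRef("price", DecimalType(12, 2))
    // Same type: the node is passed through untouched.
    assert(createCheckOverflowNode(price, DecimalType(12, 2), DecimalType(12, 2), nullOnOverflow = true) == price)
    // Different type: a cast is emitted; it must fail on overflow because nullOnOverflow is false.
    assert(createCheckOverflowNode(price, DecimalType(12, 2), DecimalType(10, 2), nullOnOverflow = false) ==
      CastNode(price, DecimalType(10, 2), failOnOverflow = true))
  }
}
```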
@@ -85,7 +85,7 @@ trait BackendSettingsApi {
def supportShuffleWithProject(outputPartitioning: Partitioning, child: SparkPlan): Boolean = false
def utilizeShuffledHashJoinHint(): Boolean = false
def excludeScanExecFromCollapsedStage(): Boolean = false
def rescaleDecimalLiteral: Boolean = false
def rescaleDecimalArithmetic: Boolean = false

/**
* Whether to replace sort agg with hash agg., e.g., sort agg will be used in spark's planning for
@@ -106,8 +106,6 @@
*/
def transformCheckOverflow: Boolean = true

def rescaleDecimalIntegralExpression(): Boolean = false

def shuffleSupportedCodec(): Set[String]

def needOutputSchemaForPlan(): Boolean = false
@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
import org.apache.spark.sql.connector.read.InputPartition
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, PartitionDirectory}
import org.apache.spark.sql.types.DecimalType
import org.apache.spark.sql.types.{DataType, DecimalType}
import org.apache.spark.util.collection.BitSet

import com.google.protobuf.{Any, Message}
@@ -69,6 +69,7 @@ trait TransformerApi {
args: java.lang.Object,
substraitExprName: String,
childNode: ExpressionNode,
childResultType: DataType,
dataType: DecimalType,
nullable: Boolean,
nullOnOverflow: Boolean): ExpressionNode
@@ -101,6 +101,28 @@ object ExpressionConverter extends SQLConfHelper with Logging {
}
}

private def genRescaleDecimalTransformer(
substraitName: String,
b: BinaryArithmetic,
attributeSeq: Seq[Attribute],
expressionsMap: Map[Class[_], String]): DecimalArithmeticExpressionTransformer = {
val rescaleBinary = DecimalArithmeticUtil.rescaleLiteral(b)
val (left, right) = DecimalArithmeticUtil.rescaleCastForDecimal(
DecimalArithmeticUtil.removeCastForDecimal(rescaleBinary.left),
DecimalArithmeticUtil.removeCastForDecimal(rescaleBinary.right))
val resultType = DecimalArithmeticUtil.getResultType(
b,
left.dataType.asInstanceOf[DecimalType],
right.dataType.asInstanceOf[DecimalType]
)

val leftChild =
replaceWithExpressionTransformerInternal(left, attributeSeq, expressionsMap)
val rightChild =
replaceWithExpressionTransformerInternal(right, attributeSeq, expressionsMap)
DecimalArithmeticExpressionTransformer(substraitName, leftChild, rightChild, resultType, b)
}

private def replaceWithExpressionTransformerInternal(
expr: Expression,
attributeSeq: Seq[Attribute],
@@ -492,7 +514,6 @@
expr.children.map(
replaceWithExpressionTransformerInternal(_, attributeSeq, expressionsMap)),
expr)

case CheckOverflow(b: BinaryArithmetic, decimalType, _)
if !BackendsApiManager.getSettings.transformCheckOverflow &&
DecimalArithmeticUtil.isDecimalArithmetic(b) =>
@@ -507,55 +528,25 @@
rightChild,
decimalType,
b)

case c: CheckOverflow =>
CheckOverflowTransformer(
substraitExprName,
replaceWithExpressionTransformerInternal(c.child, attributeSeq, expressionsMap),
c.child.dataType,
c)

case b: BinaryArithmetic if DecimalArithmeticUtil.isDecimalArithmetic(b) =>
DecimalArithmeticUtil.checkAllowDecimalArithmetic()
if (!BackendsApiManager.getSettings.transformCheckOverflow) {
val leftChild =
replaceWithExpressionTransformerInternal(b.left, attributeSeq, expressionsMap)
val rightChild =
replaceWithExpressionTransformerInternal(b.right, attributeSeq, expressionsMap)
DecimalArithmeticExpressionTransformer(
GenericExpressionTransformer(
substraitExprName,
leftChild,
rightChild,
b.dataType.asInstanceOf[DecimalType],
b)
} else {
val rescaleBinary = if (BackendsApiManager.getSettings.rescaleDecimalLiteral) {
DecimalArithmeticUtil.rescaleLiteral(b)
} else {
b
}
val (left, right) = DecimalArithmeticUtil.rescaleCastForDecimal(
DecimalArithmeticUtil.removeCastForDecimal(rescaleBinary.left),
DecimalArithmeticUtil.removeCastForDecimal(rescaleBinary.right))
val leftChild =
replaceWithExpressionTransformerInternal(left, attributeSeq, expressionsMap)
val rightChild =
replaceWithExpressionTransformerInternal(right, attributeSeq, expressionsMap)

val resultType = DecimalArithmeticUtil.getResultTypeForOperation(
DecimalArithmeticUtil.getOperationType(b),
DecimalArithmeticUtil
.getResultType(leftChild)
.getOrElse(left.dataType.asInstanceOf[DecimalType]),
DecimalArithmeticUtil
.getResultType(rightChild)
.getOrElse(right.dataType.asInstanceOf[DecimalType])
expr.children.map(
replaceWithExpressionTransformerInternal(_, attributeSeq, expressionsMap)),
expr
)
DecimalArithmeticExpressionTransformer(
substraitExprName,
leftChild,
rightChild,
resultType,
b)
} else {
// Without the rescale and cast removal the result is still correct on newer Spark
// versions, but it causes a performance regression in Velox.
genRescaleDecimalTransformer(substraitExprName, b, attributeSeq, expressionsMap)
}
case n: NaNvl =>
BackendsApiManager.getSparkPlanExecApiInstance.genNaNvlTransformer(
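As an aside on the rescale path above: genRescaleDecimalTransformer first calls DecimalArithmeticUtil.removeCastForDecimal on both operands, which drops the PromotePrecision(Cast(...)) wrapper Spark adds around decimal operands so the unnecessary decimal-to-decimal cast does not lose precision before the native computation. A simplified sketch of that rewrite, using invented stand-in case classes rather than Catalyst expressions:

```scala
// Illustrative sketch only -- stand-in expression classes, not Spark Catalyst or Gluten code.
object RemoveCastSketch {
  final case class DecimalType(precision: Int, scale: Int)

  sealed trait Expr { def dataType: DecimalType }
  final case class Column(name: String, dataType: DecimalType) extends Expr
  final case class Cast(child: Expr, dataType: DecimalType) extends Expr
  final case class PromotePrecision(child: Expr) extends Expr {
    override def dataType: DecimalType = child.dataType
  }

  // In this simplified model every expression is decimal-typed, so the real code's
  // "both cast target and cast child are DecimalType" check is implicit here.
  def removeCastForDecimal(e: Expr): Expr = e match {
    case PromotePrecision(Cast(child, _)) => child // drop the redundant decimal-to-decimal cast
    case other => other
  }

  def main(args: Array[String]): Unit = {
    val amount = Column("amount", DecimalType(10, 2))
    // Spark wraps a decimal operand like this before a binary arithmetic op.
    val wrapped = PromotePrecision(Cast(amount, DecimalType(16, 5)))
    assert(removeCastForDecimal(wrapped) == amount)
    assert(removeCastForDecimal(amount) == amount)
  }
}
```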
@@ -152,6 +152,7 @@ case class PosExplodeTransformer(
case class CheckOverflowTransformer(
substraitExprName: String,
child: ExpressionTransformer,
childResultType: DataType,
original: CheckOverflow)
extends ExpressionTransformer {

@@ -160,6 +161,7 @@ case class CheckOverflowTransformer(
args,
substraitExprName,
child.doTransform(args),
childResultType,
original.dataType,
original.nullable,
original.nullOnOverflow)
@@ -18,69 +18,40 @@ package org.apache.gluten.utils

import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.exception.GlutenNotSupportException
import org.apache.gluten.expression.{CheckOverflowTransformer, ChildTransformer, DecimalArithmeticExpressionTransformer, ExpressionTransformer}
import org.apache.gluten.expression.ExpressionConverter.conf

import org.apache.spark.sql.catalyst.analysis.DecimalPrecision
import org.apache.spark.sql.catalyst.expressions.{Add, BinaryArithmetic, Cast, Divide, Expression, Literal, Multiply, Pmod, PromotePrecision, Remainder, Subtract}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{ByteType, Decimal, DecimalType, IntegerType, LongType, ShortType}

import scala.annotation.tailrec
import org.apache.spark.sql.utils.DecimalTypeUtil

object DecimalArithmeticUtil {

object OperationType extends Enumeration {
type Config = Value
val ADD, SUBTRACT, MULTIPLY, DIVIDE, MOD = Value
}

private val MIN_ADJUSTED_SCALE = 6
val MAX_PRECISION = 38

// Returns the result decimal type of a decimal arithmetic computation.
def getResultTypeForOperation(
operationType: OperationType.Config,
type1: DecimalType,
type2: DecimalType): DecimalType = {
def getResultType(expr: BinaryArithmetic, type1: DecimalType, type2: DecimalType): DecimalType = {
var resultScale = 0
var resultPrecision = 0
operationType match {
case OperationType.ADD =>
expr match {
case _: Add =>
resultScale = Math.max(type1.scale, type2.scale)
resultPrecision =
resultScale + Math.max(type1.precision - type1.scale, type2.precision - type2.scale) + 1
case OperationType.SUBTRACT =>
case _: Subtract =>
resultScale = Math.max(type1.scale, type2.scale)
resultPrecision =
resultScale + Math.max(type1.precision - type1.scale, type2.precision - type2.scale) + 1
case OperationType.MULTIPLY =>
case _: Multiply =>
resultScale = type1.scale + type2.scale
resultPrecision = type1.precision + type2.precision + 1
case OperationType.DIVIDE =>
resultScale = Math.max(MIN_ADJUSTED_SCALE, type1.scale + type2.precision + 1)
case _: Divide =>
resultScale =
Math.max(DecimalType.MINIMUM_ADJUSTED_SCALE, type1.scale + type2.precision + 1)
resultPrecision = type1.precision - type1.scale + type2.scale + resultScale
case OperationType.MOD =>
resultScale = Math.max(type1.scale, type2.scale)
resultPrecision =
Math.min(type1.precision - type1.scale, type2.precision - type2.scale + resultScale)
case other =>
throw new GlutenNotSupportException(s"$other is not supported.")
}
adjustScaleIfNeeded(resultPrecision, resultScale)
}

// Returns the adjusted decimal type when the precision is larger than the maximum.
private def adjustScaleIfNeeded(precision: Int, scale: Int): DecimalType = {
var typePrecision = precision
var typeScale = scale
if (precision > MAX_PRECISION) {
val minScale = Math.min(scale, MIN_ADJUSTED_SCALE)
val delta = precision - MAX_PRECISION
typePrecision = MAX_PRECISION
typeScale = Math.max(scale - delta, minScale)
}
DecimalType(typePrecision, typeScale)
DecimalTypeUtil.adjustPrecisionScale(resultPrecision, resultScale)
}

// When casting between DecimalTypes, the unnecessary cast is skipped to avoid data loss,
@@ -98,18 +69,6 @@ object DecimalArithmeticUtil {
} else false
}

// Returns the operation type of a binary arithmetic expression.
def getOperationType(b: BinaryArithmetic): OperationType.Config = {
b match {
case _: Add => OperationType.ADD
case _: Subtract => OperationType.SUBTRACT
case _: Multiply => OperationType.MULTIPLY
case _: Divide => OperationType.DIVIDE
case other =>
throw new GlutenNotSupportException(s"$other is not supported.")
}
}

// For the decimal * 10 case, dec will be Decimal(38, 18), which makes the result precision wrong,
// so here we derive the real precision and scale of the literal.
private def getNewPrecisionScale(dec: Decimal): (Integer, Integer) = {
@@ -179,9 +138,7 @@
if (isWiderType) (e1, newE2) else (e1, e2)
}

if (!BackendsApiManager.getSettings.rescaleDecimalIntegralExpression()) {
(left, right)
} else if (!isPromoteCast(left) && isPromoteCastIntegral(right)) {
if (!isPromoteCast(left) && isPromoteCastIntegral(right)) {
// PromotePrecision(Cast(DecimalType)) has already been removed at this point.
// Decimal * cast int.
doScale(left, right)
@@ -202,66 +159,32 @@
* @return
* the expression with its child PromotePrecision->Cast wrapper removed
*/
def removeCastForDecimal(arithmeticExpr: Expression): Expression = {
arithmeticExpr match {
case precision: PromotePrecision =>
precision.child match {
case cast: Cast
if cast.dataType.isInstanceOf[DecimalType]
&& cast.child.dataType.isInstanceOf[DecimalType] =>
cast.child
case _ => arithmeticExpr
}
case _ => arithmeticExpr
}
def removeCastForDecimal(arithmeticExpr: Expression): Expression = arithmeticExpr match {
case PromotePrecision(_ @Cast(child, _: DecimalType, _, _))
if child.dataType.isInstanceOf[DecimalType] =>
child
case _ => arithmeticExpr
}

@tailrec
def getResultType(transformer: ExpressionTransformer): Option[DecimalType] = {
transformer match {
case ChildTransformer(child) =>
getResultType(child)
case CheckOverflowTransformer(_, _, original) =>
Some(original.dataType)
case DecimalArithmeticExpressionTransformer(_, _, _, resultType, _) =>
Some(resultType)
case _ => None
}
}

private def isPromoteCastIntegral(expr: Expression): Boolean = {
expr match {
case precision: PromotePrecision =>
precision.child match {
case cast: Cast if cast.dataType.isInstanceOf[DecimalType] =>
cast.child.dataType match {
case IntegerType | ByteType | ShortType | LongType => true
case _ => false
}
case _ => false
}
case _ => false
}
private def isPromoteCastIntegral(expr: Expression): Boolean = expr match {
case PromotePrecision(_ @Cast(child, _: DecimalType, _, _)) =>
child.dataType match {
case IntegerType | ByteType | ShortType | LongType => true
case _ => false
}
case _ => false
}

private def rescaleCastForOneSide(expr: Expression): Expression = {
expr match {
case precision: PromotePrecision =>
precision.child match {
case castInt: Cast
if castInt.dataType.isInstanceOf[DecimalType] &&
BackendsApiManager.getSettings.rescaleDecimalIntegralExpression() =>
castInt.child.dataType match {
case IntegerType | ByteType | ShortType =>
precision.withNewChildren(Seq(Cast(castInt.child, DecimalType(10, 0))))
case LongType =>
precision.withNewChildren(Seq(Cast(castInt.child, DecimalType(20, 0))))
case _ => expr
}
case _ => expr
}
case _ => expr
}
private def rescaleCastForOneSide(expr: Expression): Expression = expr match {
case precision @ PromotePrecision(_ @Cast(child, _: DecimalType, _, _)) =>
child.dataType match {
case IntegerType | ByteType | ShortType =>
precision.withNewChildren(Seq(Cast(child, DecimalType(10, 0))))
case LongType =>
precision.withNewChildren(Seq(Cast(child, DecimalType(20, 0))))
case _ => expr
}
case _ => expr
}

private def checkIsWiderType(
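For reference, the result-type rules that the new getResultType encodes (formerly getResultTypeForOperation), plus the precision/scale capping that the removed adjustScaleIfNeeded performed and that DecimalTypeUtil.adjustPrecisionScale is expected to apply, can be sketched standalone as below. The Dec case class is an invented stand-in for Spark's DecimalType; this is illustrative only, not the Gluten implementation.

```scala
// Illustrative sketch only -- mirrors the precision/scale rules visible in this diff,
// not Spark's or Gluten's actual DecimalType code.
object DecimalResultTypeSketch {
  final case class Dec(precision: Int, scale: Int)

  private val MaxPrecision = 38
  private val MinAdjustedScale = 6

  // Cap precision at 38, trading away scale first but keeping at least MinAdjustedScale
  // fractional digits (same shape as the removed adjustScaleIfNeeded helper).
  private def adjust(precision: Int, scale: Int): Dec =
    if (precision <= MaxPrecision) Dec(precision, scale)
    else {
      val minScale = math.min(scale, MinAdjustedScale)
      val delta = precision - MaxPrecision
      Dec(MaxPrecision, math.max(scale - delta, minScale))
    }

  def add(t1: Dec, t2: Dec): Dec = { // also covers Subtract
    val s = math.max(t1.scale, t2.scale)
    adjust(s + math.max(t1.precision - t1.scale, t2.precision - t2.scale) + 1, s)
  }

  def multiply(t1: Dec, t2: Dec): Dec =
    adjust(t1.precision + t2.precision + 1, t1.scale + t2.scale)

  def divide(t1: Dec, t2: Dec): Dec = {
    val s = math.max(MinAdjustedScale, t1.scale + t2.precision + 1)
    adjust(t1.precision - t1.scale + t2.scale + s, s)
  }

  def main(args: Array[String]): Unit = {
    // Decimal(10,2) * Decimal(5,3): precision 10+5+1 = 16, scale 2+3 = 5.
    assert(multiply(Dec(10, 2), Dec(5, 3)) == Dec(16, 5))
    // Decimal(38,10) * Decimal(38,10): raw (77,20) is capped to (38,6).
    assert(multiply(Dec(38, 10), Dec(38, 10)) == Dec(38, 6))
  }
}
```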