Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
jinchengchenghh committed May 9, 2024
1 parent 843a1a5 commit 0d68d43
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 104 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -492,22 +492,17 @@ object ExpressionConverter extends SQLConfHelper with Logging {
expr.children.map(
replaceWithExpressionTransformerInternal(_, attributeSeq, expressionsMap)),
expr)

case CheckOverflow(b: BinaryArithmetic, decimalType, _) =>
genDecimalExpressionTransformer(
b,
getAndCheckSubstraitName(b, expressionsMap),
decimalType,
attributeSeq,
expressionsMap)

case b: BinaryArithmetic if DecimalArithmeticUtil.isDecimalArithmetic(b) =>
genDecimalExpressionTransformer(
b,
case c: CheckOverflow =>
CheckOverflowTransformer(
substraitExprName,
b.dataType.asInstanceOf[DecimalType],
attributeSeq,
expressionsMap)
replaceWithExpressionTransformerInternal(c.child, attributeSeq, expressionsMap),
c)
case n: NaNvl =>
BackendsApiManager.getSparkPlanExecApiInstance.genNaNvlTransformer(
substraitExprName,
Expand Down Expand Up @@ -633,7 +628,6 @@ object ExpressionConverter extends SQLConfHelper with Logging {
private def genDecimalExpressionTransformer(
b: BinaryArithmetic,
substraitExprName: String,
resultType: DecimalType,
attributeSeq: Seq[Attribute],
expressionsMap: Map[Class[_], String]) = {
DecimalArithmeticUtil.checkAllowDecimalArithmetic()
Expand All @@ -645,16 +639,17 @@ object ExpressionConverter extends SQLConfHelper with Logging {
val (left, right) = DecimalArithmeticUtil.rescaleCastForDecimal(
DecimalArithmeticUtil.removeCastForDecimal(rescaleBinary.left),
DecimalArithmeticUtil.removeCastForDecimal(rescaleBinary.right))
val resultType = DecimalArithmeticUtil.getResultType(
b,
left.dataType.asInstanceOf[DecimalType],
right.dataType.asInstanceOf[DecimalType]
)
val leftChild =
replaceWithExpressionTransformerInternal(left, attributeSeq, expressionsMap)
val rightChild =
replaceWithExpressionTransformerInternal(right, attributeSeq, expressionsMap)
DecimalArithmeticExpressionTransformer(
substraitExprName,
leftChild,
rightChild,
b.dataType.asInstanceOf[DecimalType],
b)

DecimalArithmeticExpressionTransformer(substraitExprName, leftChild, rightChild, resultType, b)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,62 +25,40 @@ import org.apache.spark.sql.catalyst.analysis.DecimalPrecision
import org.apache.spark.sql.catalyst.expressions.{Add, BinaryArithmetic, Cast, Divide, Expression, Literal, Multiply, Pmod, PromotePrecision, Remainder, Subtract}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{ByteType, Decimal, DecimalType, IntegerType, LongType, ShortType}
import org.apache.spark.sql.utils.DecimalTypeUtil

import scala.annotation.tailrec

object DecimalArithmeticUtil {

// Kinds of binary decimal arithmetic an expression can be classified as.
object OperationType extends Enumeration {
  // Alias used in signatures that accept or return an operation kind.
  type Config = Value

  // One declaration per kind; ids follow declaration order (ADD = 0 ... MOD = 4).
  val ADD = Value
  val SUBTRACT = Value
  val MULTIPLY = Value
  val DIVIDE = Value
  val MOD = Value
}

private val MIN_ADJUSTED_SCALE = 6
val MAX_PRECISION = 38

// Returns the result decimal type of a decimal arithmetic computing.
def getResultTypeForOperation(
operationType: OperationType.Config,
type1: DecimalType,
type2: DecimalType): DecimalType = {
def getResultType(expr: BinaryArithmetic, type1: DecimalType, type2: DecimalType): DecimalType = {
var resultScale = 0
var resultPrecision = 0
operationType match {
case OperationType.ADD =>
expr match {
case _: Add =>
resultScale = Math.max(type1.scale, type2.scale)
resultPrecision =
resultScale + Math.max(type1.precision - type1.scale, type2.precision - type2.scale) + 1
case OperationType.SUBTRACT =>
case _: Subtract =>
resultScale = Math.max(type1.scale, type2.scale)
resultPrecision =
resultScale + Math.max(type1.precision - type1.scale, type2.precision - type2.scale) + 1
case OperationType.MULTIPLY =>
case _: Multiply =>
resultScale = type1.scale + type2.scale
resultPrecision = type1.precision + type2.precision + 1
case OperationType.DIVIDE =>
resultScale = Math.max(MIN_ADJUSTED_SCALE, type1.scale + type2.precision + 1)
case _: Divide =>
resultScale =
Math.max(DecimalType.MINIMUM_ADJUSTED_SCALE, type1.scale + type2.precision + 1)
resultPrecision = type1.precision - type1.scale + type2.scale + resultScale
case OperationType.MOD =>
case _: Pmod =>
resultScale = Math.max(type1.scale, type2.scale)
resultPrecision =
Math.min(type1.precision - type1.scale, type2.precision - type2.scale + resultScale)
case other =>
throw new GlutenNotSupportException(s"$other is not supported.")
}
adjustScaleIfNeeded(resultPrecision, resultScale)
}

// Returns the adjusted decimal type when the precision is larger than the maximum.
private def adjustScaleIfNeeded(precision: Int, scale: Int): DecimalType = {
var typePrecision = precision
var typeScale = scale
if (precision > MAX_PRECISION) {
val minScale = Math.min(scale, MIN_ADJUSTED_SCALE)
val delta = precision - MAX_PRECISION
typePrecision = MAX_PRECISION
typeScale = Math.max(scale - delta, minScale)
}
DecimalType(typePrecision, typeScale)
DecimalTypeUtil.adjustPrecisionScale(resultPrecision, resultScale)
}

// If casting between DecimalType, unnecessary cast is skipped to avoid data loss,
Expand All @@ -98,18 +76,6 @@ object DecimalArithmeticUtil {
} else false
}

// Classifies a binary arithmetic expression as one of the OperationType values.
// Add, Subtract, Multiply and Divide are recognised; any other expression is rejected
// with a GlutenNotSupportException whose message names the offending expression.
def getOperationType(b: BinaryArithmetic): OperationType.Config = b match {
  case _: Add => OperationType.ADD
  case _: Subtract => OperationType.SUBTRACT
  case _: Multiply => OperationType.MULTIPLY
  case _: Divide => OperationType.DIVIDE
  case unsupported =>
    throw new GlutenNotSupportException(s"$unsupported is not supported.")
}

// For decimal * 10 case, dec will be Decimal(38, 18), then the result precision is wrong,
// so here we will get the real precision and scale of the literal.
private def getNewPrecisionScale(dec: Decimal): (Integer, Integer) = {
Expand Down Expand Up @@ -202,18 +168,12 @@ object DecimalArithmeticUtil {
* @return
* the expression with its wrapping PromotePrecision->Cast removed
*/
def removeCastForDecimal(arithmeticExpr: Expression): Expression = {
arithmeticExpr match {
case precision: PromotePrecision =>
precision.child match {
case cast: Cast
if cast.dataType.isInstanceOf[DecimalType]
&& cast.child.dataType.isInstanceOf[DecimalType] =>
cast.child
case _ => arithmeticExpr
}
case _ => arithmeticExpr
}
def removeCastForDecimal(arithmeticExpr: Expression): Expression = {
  // True when the expression's result type is a decimal.
  def isDecimal(e: Expression): Boolean = e.dataType.isInstanceOf[DecimalType]

  arithmeticExpr match {
    // PromotePrecision(Cast(decimal -> decimal)): hand back the operand under the cast.
    case PromotePrecision(decimalCast: Cast)
        if isDecimal(decimalCast) && isDecimal(decimalCast.child) =>
      decimalCast.child
    // Anything else is returned untouched.
    case _ => arithmeticExpr
  }
}

@tailrec
Expand All @@ -229,39 +189,27 @@ object DecimalArithmeticUtil {
}
}

private def isPromoteCastIntegral(expr: Expression): Boolean = {
expr match {
case precision: PromotePrecision =>
precision.child match {
case cast: Cast if cast.dataType.isInstanceOf[DecimalType] =>
cast.child.dataType match {
case IntegerType | ByteType | ShortType | LongType => true
case _ => false
}
case _ => false
}
case _ => false
}
// Tells whether `expr` is a PromotePrecision wrapping a Cast to a decimal type whose
// input is a Byte, Short, Integer or Long value. Every other shape yields false.
private def isPromoteCastIntegral(expr: Expression): Boolean = {
  expr match {
    case PromotePrecision(promoted: Cast) =>
      promoted.dataType.isInstanceOf[DecimalType] && (promoted.child.dataType match {
        case ByteType | ShortType | IntegerType | LongType => true
        case _ => false
      })
    case _ => false
  }
}

private def rescaleCastForOneSide(expr: Expression): Expression = {
expr match {
case precision: PromotePrecision =>
precision.child match {
case castInt: Cast
if castInt.dataType.isInstanceOf[DecimalType] &&
BackendsApiManager.getSettings.rescaleDecimalIntegralExpression() =>
castInt.child.dataType match {
case IntegerType | ByteType | ShortType =>
precision.withNewChildren(Seq(Cast(castInt.child, DecimalType(10, 0))))
case LongType =>
precision.withNewChildren(Seq(Cast(castInt.child, DecimalType(20, 0))))
case _ => expr
}
case _ => expr
}
case _ => expr
}
// Narrows the decimal cast applied to an integral operand.
// Applies only to PromotePrecision(Cast(_ -> decimal)) and only when the backend settings
// report rescaleDecimalIntegralExpression() as true: Byte/Short/Integer inputs are re-cast
// to DecimalType(10, 0) and Long inputs to DecimalType(20, 0). In every other situation
// the incoming expression is returned as is.
private def rescaleCastForOneSide(expr: Expression): Expression = {
  expr match {
    case promoted @ PromotePrecision(integralCast: Cast)
        if integralCast.dataType.isInstanceOf[DecimalType] &&
          BackendsApiManager.getSettings.rescaleDecimalIntegralExpression() =>
      val operand = integralCast.child
      // Precision of the replacement decimal type, if the operand is integral.
      val narrowedPrecision: Option[Int] = operand.dataType match {
        case ByteType | ShortType | IntegerType => Some(10)
        case LongType => Some(20)
        case _ => None
      }
      narrowedPrecision
        .map(p => promoted.withNewChildren(Seq(Cast(operand, DecimalType(p, 0)))))
        .getOrElse(expr)
    case _ => expr
  }
}

private def checkIsWiderType(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.utils

import org.apache.spark.sql.types.DecimalType

/** Thin bridge exposing Spark's decimal precision/scale adjustment to callers. */
object DecimalTypeUtil {

  /**
   * Forwards to `DecimalType.adjustPrecisionScale` without altering the arguments or result.
   *
   * NOTE(review): presumably this wrapper sits under `org.apache.spark.sql` because the
   * forwarded method is not visible from other packages -- confirm against the Spark sources.
   *
   * @param precision
   *   requested precision
   * @param scale
   *   requested scale
   * @return
   *   the decimal type produced by `DecimalType.adjustPrecisionScale` for this pair
   */
  def adjustPrecisionScale(precision: Int, scale: Int): DecimalType =
    DecimalType.adjustPrecisionScale(precision, scale)
}

0 comments on commit 0d68d43

Please sign in to comment.