fix round function
loneylee committed Jul 12, 2024
1 parent 4911e99 commit f544787
Showing 5 changed files with 61 additions and 22 deletions.
@@ -57,7 +57,7 @@ import org.apache.spark.sql.execution.joins.{BuildSideRelation, ClickHouseBuildS
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.utils.{CHExecUtil, PushDownUtil}
import org.apache.spark.sql.extension.{CommonSubexpressionEliminateRule, RewriteDateTimestampComparisonRule}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.{DecimalType, StructType}
import org.apache.spark.sql.vectorized.ColumnarBatch

import org.apache.commons.lang3.ClassUtils
@@ -900,4 +900,10 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
override def genPreProjectForGenerate(generate: GenerateExec): SparkPlan = generate

override def genPostProjectForGenerate(generate: GenerateExec): SparkPlan = generate

override def genDecimalRoundExpressionOutput(
decimalType: DecimalType,
toScale: Int): DecimalType = {
SparkShimLoader.getSparkShims.genDecimalRoundExpressionOutput(decimalType, toScale)
}
}
@@ -47,7 +47,7 @@ import org.apache.spark.sql.execution.joins.BuildSideRelation
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.python.ArrowEvalPythonExec
import org.apache.spark.sql.hive.HiveTableScanExecTransformer
import org.apache.spark.sql.types.{LongType, NullType, StructType}
import org.apache.spark.sql.types.{DecimalType, LongType, NullType, StructType}
import org.apache.spark.sql.vectorized.ColumnarBatch

import java.lang.{Long => JLong}
@@ -712,4 +712,23 @@ trait SparkPlanExecApi {
arrowEvalPythonExec

def maybeCollapseTakeOrderedAndProject(plan: SparkPlan): SparkPlan = plan

def genDecimalRoundExpressionOutput(decimalType: DecimalType, toScale: Int): DecimalType = {
val p = decimalType.precision
val s = decimalType.scale
// After rounding we may need one more digit in the integral part,
// e.g. `ceil(9.9, 0)` -> `10`, `ceil(99, -1)` -> `100`.
val integralLeastNumDigits = p - s + 1
if (toScale < 0) {
// negative scale means we need to adjust `-scale` number of digits before the decimal
// point, which means we need at least `-scale + 1` digits (after rounding).
val newPrecision = math.max(integralLeastNumDigits, -toScale + 1)
// We have to accept the risk of overflow as we can't exceed the max precision.
DecimalType(math.min(newPrecision, DecimalType.MAX_PRECISION), 0)
} else {
val newScale = math.min(s, toScale)
// We have to accept the risk of overflow as we can't exceed the max precision.
DecimalType(math.min(integralLeastNumDigits + newScale, 38), newScale)
}
}
}
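
For reference, the result-type rule added above can be restated as a standalone snippet with a couple of worked cases. This is a plain Scala sketch, not Gluten code: `roundResultType` is a made-up name, and only spark-sql is assumed on the classpath.

import org.apache.spark.sql.types.DecimalType

// Standalone restatement of the version-independent rule above.
def roundResultType(input: DecimalType, toScale: Int): DecimalType = {
  val integralLeastNumDigits = input.precision - input.scale + 1
  if (toScale < 0) {
    // Need at least `-toScale + 1` integral digits after rounding; the scale becomes 0.
    DecimalType(
      math.min(math.max(integralLeastNumDigits, -toScale + 1), DecimalType.MAX_PRECISION),
      0)
  } else {
    val newScale = math.min(input.scale, toScale)
    DecimalType(math.min(integralLeastNumDigits + newScale, DecimalType.MAX_PRECISION), newScale)
  }
}

// round(DECIMAL(10,5), 2): 10 - 5 + 1 = 6 integral digits plus scale 2 -> DECIMAL(8,2)
assert(roundResultType(DecimalType(10, 5), 2) == DecimalType(8, 2))
// round(DECIMAL(5,2), -1): max(4, 2) = 4 integral digits, scale 0 -> DECIMAL(4,0)
assert(roundResultType(DecimalType(5, 2), -1) == DecimalType(4, 0))
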
@@ -16,6 +16,7 @@
*/
package org.apache.gluten.expression

import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.exception.GlutenNotSupportException

import org.apache.spark.sql.catalyst.expressions._
@@ -29,25 +30,13 @@ case class DecimalRoundTransformer(

val toScale: Int = original.scale.eval(EmptyRow).asInstanceOf[Int]

// Use the same result type for different Spark versions.
// Use the same result type across different Spark versions in Velox,
// but match Spark's own result type in CH.
override val dataType: DataType = original.child.dataType match {
case decimalType: DecimalType =>
val p = decimalType.precision
val s = decimalType.scale
// After rounding we may need one more digit in the integral part,
// e.g. `ceil(9.9, 0)` -> `10`, `ceil(99, -1)` -> `100`.
val integralLeastNumDigits = p - s + 1
if (toScale < 0) {
// negative scale means we need to adjust `-scale` number of digits before the decimal
// point, which means we need at least `-scale + 1` digits (after rounding).
val newPrecision = math.max(integralLeastNumDigits, -toScale + 1)
// We have to accept the risk of overflow as we can't exceed the max precision.
DecimalType(math.min(newPrecision, DecimalType.MAX_PRECISION), 0)
} else {
val newScale = math.min(s, toScale)
// We have to accept the risk of overflow as we can't exceed the max precision.
DecimalType(math.min(integralLeastNumDigits + newScale, 38), newScale)
}
BackendsApiManager.getSparkPlanExecApiInstance.genDecimalRoundExpressionOutput(
decimalType,
toScale)
case _ =>
throw new GlutenNotSupportException(
s"Decimal type is expected but received ${original.child.dataType.typeName}.")
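
For context on the unchanged `toScale` line in this hunk: Round's scale argument is a foldable expression, so the transformer can evaluate it once without an input row. A minimal illustration of that evaluation, assuming only spark-catalyst on the classpath (the literal value is made up):

import org.apache.spark.sql.catalyst.expressions.{EmptyRow, Literal}

// A literal needs no input row, so eval(EmptyRow) yields its constant value.
val toScale: Int = Literal(2).eval(EmptyRow).asInstanceOf[Int]
assert(toScale == 2)
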
@@ -38,12 +38,12 @@ import org.apache.spark.sql.connector.read.{InputPartition, Scan}
import org.apache.spark.sql.execution.{FileSourceScanExec, GlobalLimitExec, SparkPlan, TakeOrderedAndProjectExec}
import org.apache.spark.sql.execution.command.DataWritingCommandExec
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, ParquetFilters}
import org.apache.spark.sql.execution.datasources.parquet.ParquetFilters
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.execution.datasources.v2.text.TextScan
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeLike, ShuffleExchangeLike}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{StructField, StructType}
import org.apache.spark.sql.types.{DecimalType, StructType}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.storage.{BlockId, BlockManagerId}

@@ -248,4 +248,23 @@ trait SparkShims {
conf: SQLConf,
schema: MessageType,
caseSensitive: Option[Boolean] = None): ParquetFilters

def genDecimalRoundExpressionOutput(decimalType: DecimalType, toScale: Int): DecimalType = {
val p = decimalType.precision
val s = decimalType.scale
// After rounding we may need one more digit in the integral part,
// e.g. `ceil(9.9, 0)` -> `10`, `ceil(99, -1)` -> `100`.
val integralLeastNumDigits = p - s + 1
if (toScale < 0) {
// negative scale means we need to adjust `-scale` number of digits before the decimal
// point, which means we need at least `-scale + 1` digits (after rounding).
val newPrecision = math.max(integralLeastNumDigits, -toScale + 1)
// We have to accept the risk of overflow as we can't exceed the max precision.
DecimalType(math.min(newPrecision, DecimalType.MAX_PRECISION), 0)
} else {
val newScale = math.min(s, toScale)
// We have to accept the risk of overflow as we can't exceed the max precision.
DecimalType(math.min(integralLeastNumDigits + newScale, 38), newScale)
}
}
}
@@ -46,7 +46,7 @@ import org.apache.spark.sql.execution.datasources.v2.utils.CatalogUtil
import org.apache.spark.sql.execution.exchange.BroadcastExchangeLike
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy
import org.apache.spark.sql.types.{StructField, StructType}
import org.apache.spark.sql.types.{DecimalType, StructField, StructType}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.storage.{BlockId, BlockManagerId}

@@ -276,4 +276,10 @@ class Spark32Shims extends SparkShims {
RebaseSpec(LegacyBehaviorPolicy.CORRECTED)
)
}

override def genDecimalRoundExpressionOutput(
decimalType: DecimalType,
toScale: Int): DecimalType = {
val p = decimalType.precision
val s = decimalType.scale
DecimalType(p, if (toScale > s) s else toScale)
}
}
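
By contrast, the Spark32Shims override above keeps the input precision and only caps the scale, matching Spark 3.2's own Round result type. A standalone sketch with worked cases on the same inputs as the earlier example (plain Scala, made-up function name, spark-sql assumed on the classpath):

import org.apache.spark.sql.types.DecimalType

// Spark-3.2-style rule: precision is preserved, scale is capped at the input scale.
def spark32RoundType(input: DecimalType, toScale: Int): DecimalType =
  DecimalType(input.precision, if (toScale > input.scale) input.scale else toScale)

// round(DECIMAL(10,5), 2) -> DECIMAL(10,2), whereas the version-independent rule gives DECIMAL(8,2)
assert(spark32RoundType(DecimalType(10, 5), 2) == DecimalType(10, 2))
// round(DECIMAL(10,5), 7): the requested scale exceeds the input scale, so it stays DECIMAL(10,5)
assert(spark32RoundType(DecimalType(10, 5), 7) == DecimalType(10, 5))
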
