fix decimal
WangGuangxin committed Jun 11, 2024
1 parent ed6b171 commit e726284
Showing 3 changed files with 25 additions and 2 deletions.
@@ -24,6 +24,7 @@ import org.apache.gluten.sql.shims.SparkShimLoader
 import org.apache.spark.SparkConf
 import org.apache.spark.sql.{AnalysisException, DataFrame, Row}
 import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.window.WindowExec
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{DecimalType, IntegerType, StringType, StructField, StructType}
@@ -215,7 +216,7 @@ class TestOperator extends VeloxWholeStageTransformerSuite {
         " (partition by l_suppkey order by l_orderkey" +
         " RANGE BETWEEN 1 PRECEDING AND CURRENT ROW), " +
         "min(l_comment) over" +
-        " (partition by l_suppkey order by l_discount" +
+        " (partition by l_suppkey order by l_linenumber" +
         " RANGE BETWEEN 1 PRECEDING AND CURRENT ROW) from lineitem ") {
       checkSparkOperatorMatch[WindowExecTransformer]
     }
@@ -255,6 +256,14 @@ class TestOperator extends VeloxWholeStageTransformerSuite {
       checkSparkOperatorMatch[WindowExecTransformer]
     }
 
+    // DecimalType as the order by column is not supported
+    runQueryAndCompare(
+      "select min(l_comment) over" +
+        " (partition by l_suppkey order by l_discount" +
+        " RANGE BETWEEN 1 PRECEDING AND CURRENT ROW) from lineitem ") {
+      checkSparkOperatorMatch[WindowExec]
+    }
+
     runQueryAndCompare(
       "select ntile(4) over" +
         " (partition by l_suppkey order by l_orderkey) from lineitem ") {
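Note that the new test pins down a fallback rather than an offload: with a DecimalType sort key (l_discount) in a RANGE frame, the window should stay on vanilla Spark's WindowExec instead of becoming Gluten's WindowExecTransformer. A minimal sketch of the same plan-level check using only standard Spark APIs (assuming a running spark session with the TPC-H lineitem table registered, as in this suite):

    import org.apache.spark.sql.execution.window.WindowExec

    val df = spark.sql(
      "select min(l_comment) over" +
        " (partition by l_suppkey order by l_discount" +
        " RANGE BETWEEN 1 PRECEDING AND CURRENT ROW) from lineitem")

    // The executed plan should still contain vanilla Spark's WindowExec,
    // i.e. the window was not offloaded because of the decimal sort key.
    assert(df.queryExecution.executedPlan.collect { case w: WindowExec => w }.nonEmpty)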
@@ -79,7 +79,10 @@ object PullOutPreProject extends RewriteSingleNode with PullOutProjectHelper {
       window.windowExpression.exists(_.find {
         case we: WindowExpression =>
           we.windowSpec.frameSpecification match {
-            case swf: SpecifiedWindowFrame if needPreComputeRangeFrame(swf) => true
+            case swf: SpecifiedWindowFrame
+                if needPreComputeRangeFrame(swf) && supportPreComputeRangeFrame(
+                  we.windowSpec.orderSpec) =>
+              true
             case _ => false
           }
         case _ => false
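For context, "pre-compute range frame" here refers to PullOutPreProject materializing a literal RANGE bound as a pre-projected column (inferred from the helper names in this diff, not a verbatim quote of the rewrite). A sketch with standard Catalyst expressions:

    import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Literal, Subtract}
    import org.apache.spark.sql.types.IntegerType

    // Hypothetical pre-projected lower bound for
    // "order by l_linenumber RANGE BETWEEN 1 PRECEDING AND CURRENT ROW":
    val sortKey = AttributeReference("l_linenumber", IntegerType)()
    val preComputedLower = Subtract(sortKey, Literal(1))

This boundary arithmetic is well-defined for integral and date sort keys; a DecimalType sort key is now rejected by supportPreComputeRangeFrame, so the whole window is left unrewritten and falls back to vanilla Spark.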
@@ -22,6 +22,7 @@ import org.apache.gluten.exception.{GlutenException, GlutenNotSupportException}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction}
 import org.apache.spark.sql.execution.aggregate._
+import org.apache.spark.sql.types.{ByteType, DateType, IntegerType, LongType, ShortType}
 
 import java.util.concurrent.atomic.AtomicInteger
 
@@ -174,6 +175,16 @@ trait PullOutProjectHelper {
     (needPreComputeRangeFrameBoundary(swf.lower) || needPreComputeRangeFrameBoundary(swf.upper))
   }
 
+  protected def supportPreComputeRangeFrame(sortOrders: Seq[SortOrder]): Boolean = {
+    sortOrders.forall {
+      _.dataType match {
+        // Only integral types and DateType are supported as the sort key of a range frame
+        case ByteType | ShortType | IntegerType | LongType | DateType => true
+        case _ => false
+      }
+    }
+  }
+
   protected def rewriteWindowExpression(
       we: WindowExpression,
       orderSpecs: Seq[SortOrder],
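A standalone illustration of the new helper's behavior (the real method is protected inside PullOutProjectHelper, so this sketch inlines the same check; l_discount's decimal precision is assumed):

    import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, SortOrder}
    import org.apache.spark.sql.types.{ByteType, DateType, DecimalType, IntegerType, LongType, ShortType}

    // Standalone copy of the supportPreComputeRangeFrame logic, for demonstration only.
    def supportsRangeFrameSortKey(sortOrders: Seq[SortOrder]): Boolean =
      sortOrders.forall(_.dataType match {
        case ByteType | ShortType | IntegerType | LongType | DateType => true
        case _ => false
      })

    val byLineNumber = SortOrder(AttributeReference("l_linenumber", IntegerType)(), Ascending)
    val byDiscount = SortOrder(AttributeReference("l_discount", DecimalType(15, 2))(), Ascending)

    assert(supportsRangeFrameSortKey(Seq(byLineNumber))) // integral key: offloadable
    assert(!supportsRangeFrameSortKey(Seq(byDiscount))) // decimal key: falls back to WindowExec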
