[GLUTEN-2031][VL] Enable lag window function (apache#2737)
PHILO-HE authored Feb 28, 2024
1 parent eb276b1 commit 65fb100
Showing 6 changed files with 36 additions and 23 deletions.
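For context: `lag` returns the value of its argument from a preceding row of the same window partition. Below is a minimal sketch, not part of the commit, of the kind of Spark query this change lets Gluten offload to Velox; the `spark` session and the registered TPC-H `lineitem` table are assumptions, and the query shape mirrors the new test case added below.

```scala
import org.apache.spark.sql.SparkSession

// Assumed setup: a running SparkSession with Gluten enabled and a
// registered TPC-H lineitem table; neither is defined by this commit.
val spark = SparkSession.builder().appName("lag-offload-sketch").getOrCreate()
spark.sql(
  """select l_orderkey,
    |       lag(l_orderkey) over (partition by l_suppkey order by l_orderkey) as prev_key
    |  from lineitem""".stripMargin
).show()
```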
```diff
@@ -26,7 +26,7 @@ import io.glutenproject.substrait.rel.LocalFilesNode.ReadFileFormat
 import io.glutenproject.substrait.rel.LocalFilesNode.ReadFileFormat.{DwrfReadFormat, OrcReadFormat, ParquetReadFormat}

 import org.apache.spark.sql.catalyst.catalog.BucketSpec
-import org.apache.spark.sql.catalyst.expressions.{Alias, CumeDist, DenseRank, Descending, Expression, Literal, NamedExpression, NthValue, NTile, PercentRank, Rand, RangeFrame, Rank, RowNumber, SortOrder, SpecialFrameBoundary, SpecifiedWindowFrame}
+import org.apache.spark.sql.catalyst.expressions.{Alias, CumeDist, DenseRank, Descending, Expression, Lag, Literal, NamedExpression, NthValue, NTile, PercentRank, Rand, RangeFrame, Rank, RowNumber, SortOrder, SpecialFrameBoundary, SpecifiedWindowFrame}
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Count, Sum}
 import org.apache.spark.sql.catalyst.plans.JoinType
 import org.apache.spark.sql.catalyst.util.CharVarcharUtils
@@ -286,7 +286,7 @@ object BackendSettings extends BackendSettingsApi {
       }
       windowExpression.windowFunction match {
         case _: RowNumber | _: AggregateExpression | _: Rank | _: CumeDist | _: DenseRank |
-            _: PercentRank | _: NthValue | _: NTile =>
+            _: PercentRank | _: NthValue | _: NTile | _: Lag =>
         case _ =>
           allSupported = false
       }
```
```diff
@@ -278,6 +278,12 @@ class TestOperator extends VeloxWholeStageTransformerSuite with AdaptiveSparkPlanHelper {
       assertWindowOffloaded
     }

+    runQueryAndCompare(
+      "select lag(l_orderkey) over" +
+        " (partition by l_suppkey order by l_orderkey) from lineitem ") {
+      assertWindowOffloaded
+    }
+
     // Test same partition/ordering keys.
     runQueryAndCompare(
       "select avg(l_partkey) over" +
```
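The new test only asserts that the window operator is offloaded. As a hedged illustration of what `lag` itself computes, here is the same semantics on a toy in-memory partition (made-up values, plain Scala collections):

```scala
// One l_suppkey partition, already ordered by l_orderkey (hypothetical values).
val orderkeys = Seq(1L, 4L, 7L)
// lag(l_orderkey): each row sees the previous row's value; the first row,
// having no predecessor, gets None (SQL NULL) as the implicit default.
val lagged: Seq[Option[Long]] = None +: orderkeys.init.map(Option(_))
assert(lagged == Seq(None, Some(1L), Some(4L)))
```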
cpp/velox/substrait/SubstraitToVeloxPlan.h (2 changes: 1 addition & 1 deletion)

```diff
@@ -27,7 +27,7 @@ namespace gluten {
 class ResultIterator;

 // Holds names of Spark OffsetWindowFunctions.
-static const std::unordered_set<std::string> kOffsetWindowFunctions = {"nth_value"};
+static const std::unordered_set<std::string> kOffsetWindowFunctions = {"nth_value", "lag"};

 struct SplitInfo {
   /// Whether the split comes from arrow array stream node.
```
cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc (2 changes: 1 addition & 1 deletion)

```diff
@@ -611,7 +611,7 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::WindowRel& windowRel)
      case ::substrait::Expression::RexTypeCase::kLiteral:
        break;
      default:
-        LOG_VALIDATION_MSG("Only field is supported in window functions.");
+        LOG_VALIDATION_MSG("Only field or constant is supported in window functions.");
        return false;
    }
  }
```
```diff
@@ -43,7 +43,7 @@ import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
 import org.apache.spark.sql.execution.joins.BuildSideRelation
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.hive.HiveTableScanExecTransformer
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{BooleanType, LongType, NullType, StructType}
 import org.apache.spark.sql.vectorized.ColumnarBatch

 import java.lang.{Long => JLong}
@@ -501,32 +501,39 @@ trait SparkPlanExecApi {
           )
           windowExpressionNodes.add(windowFunctionNode)
         case wf @ (Lead(_, _, _, _) | Lag(_, _, _, _)) =>
-          val offset_wf = wf.asInstanceOf[FrameLessOffsetWindowFunction]
-          val frame = offset_wf.frame.asInstanceOf[SpecifiedWindowFrame]
+          val offsetWf = wf.asInstanceOf[FrameLessOffsetWindowFunction]
+          val frame = offsetWf.frame.asInstanceOf[SpecifiedWindowFrame]
           val childrenNodeList = new JArrayList[ExpressionNode]()
           childrenNodeList.add(
             ExpressionConverter
               .replaceWithExpressionTransformer(
-                offset_wf.input,
-                attributeSeq = originalInputAttributes)
-              .doTransform(args))
-          childrenNodeList.add(
-            ExpressionConverter
-              .replaceWithExpressionTransformer(
-                offset_wf.offset,
+                offsetWf.input,
                 attributeSeq = originalInputAttributes)
               .doTransform(args))
+          // Spark only accepts foldable offset. Converts it to LongType literal.
+          val offsetNode = ExpressionBuilder.makeLiteral(
+            // Velox always expects positive offset.
+            Math.abs(offsetWf.offset.eval(EmptyRow).asInstanceOf[Int].toLong),
+            LongType,
+            false)
+          childrenNodeList.add(offsetNode)
+          // NullType means Null is the default value. Don't pass it to native.
+          if (offsetWf.default.dataType != NullType) {
+            childrenNodeList.add(
+              ExpressionConverter
+                .replaceWithExpressionTransformer(
+                  offsetWf.default,
+                  attributeSeq = originalInputAttributes)
+                .doTransform(args))
+          }
+          // Always adds ignoreNulls at the end.
           childrenNodeList.add(
-            ExpressionConverter
-              .replaceWithExpressionTransformer(
-                offset_wf.default,
-                attributeSeq = originalInputAttributes)
-              .doTransform(args))
+            ExpressionBuilder.makeLiteral(offsetWf.ignoreNulls, BooleanType, false))
           val windowFunctionNode = ExpressionBuilder.makeWindowFunction(
-            WindowFunctionsBuilder.create(args, offset_wf).toInt,
+            WindowFunctionsBuilder.create(args, offsetWf).toInt,
             childrenNodeList,
             columnName,
-            ConverterUtils.getTypeNode(offset_wf.dataType, offset_wf.nullable),
+            ConverterUtils.getTypeNode(offsetWf.dataType, offsetWf.nullable),
             WindowExecTransformer.getFrameBound(frame.upper),
             WindowExecTransformer.getFrameBound(frame.lower),
             frame.frameType.sql
@@ -540,7 +547,8 @@ trait SparkPlanExecApi {
             .replaceWithExpressionTransformer(input, attributeSeq = originalInputAttributes)
             .doTransform(args))
           childrenNodeList.add(LiteralTransformer(offset).doTransform(args))
-          childrenNodeList.add(LiteralTransformer(Literal(ignoreNulls)).doTransform(args))
+          // Always adds ignoreNulls at the end.
+          childrenNodeList.add(ExpressionBuilder.makeLiteral(ignoreNulls, BooleanType, false))
           val windowFunctionNode = ExpressionBuilder.makeWindowFunction(
             WindowFunctionsBuilder.create(args, wf).toInt,
             childrenNodeList,
```
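One detail of the Lead/Lag branch above deserves a note: Spark's `Lag` expression folds its user-facing offset into a non-positive value (`lag(x, 2)` is held as offset `-2`), while Velox expects the positive magnitude, hence the `Math.abs` before the `LongType` literal is built. Assuming that sign convention, the conversion reduces to this sketch:

```scala
// Sketch only: mirrors the Math.abs conversion in the Lead/Lag branch above.
// For lag(x, 2), offsetWf.offset.eval(EmptyRow) is assumed to yield -2.
val sparkFoldedOffset: Int = -2
val veloxOffset: Long = Math.abs(sparkFoldedOffset.toLong)
assert(veloxOffset == 2L)
```

The child list handed to the native window function is then `[input, |offset| as a long literal, default value (skipped when its type is NullType), ignoreNulls as a boolean literal]`, matching what the added code builds.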
```diff
@@ -213,7 +213,6 @@ case class FallbackMultiCodegens(session: SparkSession) extends Rule[SparkPlan] {
      case p: ShuffledHashJoinExec =>
        tagNotTransformable(p.withNewChildren(p.children.map(tagNotTransformableRecursive)))
      case p if !supportCodegen(p) =>
-        // insert row guard them recursively
        p.withNewChildren(p.children.map(tagNotTransformableForMultiCodegens))
      case p if isAQEShuffleReadExec(p) =>
        p.withNewChildren(p.children.map(tagNotTransformableForMultiCodegens))
```
