Skip to content

Commit

Permalink
Get input offset for Spark aggregate function
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Jun 4, 2024
1 parent 3c7452f commit 5d4c1b2
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 3 deletions.
3 changes: 3 additions & 0 deletions core/src/execution/proto/expr.proto
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,9 @@ message AggExpr {
Stddev stddev = 15;
Correlation correlation = 16;
}

// The offset of this expression's input columns within the input batch.
int32 input_offset = 17;
}

enum StatisticsType {
Expand Down
23 changes: 20 additions & 3 deletions spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,8 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
def aggExprToProto(
aggExpr: AggregateExpression,
inputs: Seq[Attribute],
binding: Boolean): Option[AggExpr] = {
binding: Boolean,
inputOffset: Int): Option[AggExpr] = {
aggExpr.aggregateFunction match {
case s @ Sum(child, _) if sumDataTypeSupported(s.dataType) && isLegacyMode(s) =>
val childExpr = exprToProto(child, inputs, binding)
Expand Down Expand Up @@ -331,6 +332,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
val firstBuilder = ExprOuterClass.First.newBuilder()
firstBuilder.setChild(childExpr.get)
firstBuilder.setDatatype(dataType.get)
firstBuilder.setInputOffset(inputOffset)

Some(
ExprOuterClass.AggExpr
Expand All @@ -353,6 +355,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
val lastBuilder = ExprOuterClass.Last.newBuilder()
lastBuilder.setChild(childExpr.get)
lastBuilder.setDatatype(dataType.get)
lastBuilder.setInputOffset(inputOffset)

Some(
ExprOuterClass.AggExpr
Expand Down Expand Up @@ -2399,8 +2402,22 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
// `output` is only used when `binding` is true (i.e., non-Final)
val output = child.output

val aggExprs =
aggregateExpressions.map(aggExprToProto(_, output, binding))
// In DataFusion some aggregate expressions are implemented using AggregateUDF.
// For such expressions, we need to extract its input expressions manually in native
// code (see planner.rs). To do that, due to the limitation of DataFusion AggregateUDF
// design, we have to get the input offset for each aggregate expression.
// In DataFusion some aggregate expressions are implemented using AggregateUDF.
// For such expressions, we need to extract its input expressions manually in native
// code (see planner.rs). To do that, due to the limitation of DataFusion AggregateUDF
// design, we have to get the input offset for each aggregate expression.
val aggExprs = if (mode == CometAggregateMode.Final) {
  aggregateExpressions.map { expr =>
    // The aggregate buffer columns of this expression start at the position of its
    // first inputAggBufferAttribute within the child's output.
    val firstAttr = expr.aggregateFunction.inputAggBufferAttributes.head
    val offset = output.indexWhere(_.exprId == firstAttr.exprId)
    // Serialize only this expression with its own offset. (A nested
    // `aggregateExpressions.map` here would yield Seq[Seq[Option[AggExpr]]] and
    // break the `aggExprs.forall(_.isDefined)` check below.)
    aggExprToProto(expr, output, binding, offset)
  }
} else {
  // Non-Final modes read the raw input directly; no buffer offset applies.
  aggregateExpressions.map(aggExprToProto(_, output, binding, -1))
}

if (childOp.nonEmpty && groupingExprs.forall(_.isDefined) &&
aggExprs.forall(_.isDefined)) {
val hashAggBuilder = OperatorOuterClass.HashAggregate.newBuilder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1358,6 +1358,33 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}

test("first") {
  // Exercise FIRST() through the Comet JVM shuffle path so the Final-mode
  // aggregate (and its input offset) is covered end to end.
  withSQLConf(
    SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true",
    CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
    CometConf.COMET_SHUFFLE_ENFORCE_MODE_ENABLED.key -> "true",
    CometConf.COMET_SHUFFLE_MODE.key -> "jvm") {
    // Run with and without Parquet dictionary encoding to cover both column layouts.
    Seq(true, false).foreach { dictionary =>
      withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {
        val table = "test"
        withTable(table) {
          sql(s"create table $table(col1 int, col2 int, col3 int) using parquet")
          sql(
            s"insert into $table values(4, 1, 1), (4, 1, 1), (3, 3, 1)," +
              " (2, 4, 2), (1, 3, 2), (null, 1, 1)")
          withView("t") {
            // The ORDER BY in the view makes FIRST(col1) deterministic.
            sql("CREATE VIEW t AS SELECT col1, col3 FROM test ORDER BY col1")

            // Assert Comet matches Spark's result instead of merely collecting
            // (a bare collect() passes even when Comet returns wrong values).
            checkSparkAnswer(sql("SELECT FIRST(col1) FROM t"))
          }
        }
      }
    }
  }
}

protected def checkSparkAnswerAndNumOfAggregates(query: String, numAggregates: Int): Unit = {
val df = sql(query)
checkSparkAnswer(df)
Expand Down

0 comments on commit 5d4c1b2

Please sign in to comment.