Skip to content

Commit

Permalink
[GLUTEN-3644][CH] Revert the logic to support the custom aggregate fu…
Browse files Browse the repository at this point in the history
…nctions

In PR #3629, it removes the logic to support the custom aggregate functions, must be reverted.

Close #3644.
  • Loading branch information
zzcclp committed Nov 8, 2023
1 parent b60fe75 commit b06c9ae
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,11 @@ package io.glutenproject.execution.extension
import io.glutenproject.expression._
import io.glutenproject.extension.ExpressionExtensionTrait

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._

import scala.collection.mutable.ListBuffer

case class CustomAggExpressionTransformer() extends ExpressionExtensionTrait {

lazy val expressionSigs = Seq(
Expand All @@ -29,4 +32,42 @@ case class CustomAggExpressionTransformer() extends ExpressionExtensionTrait {

/** Generate the extension expressions list, format: Sig[XXXExpression]("XXXExpressionName") */
override def expressionSigList: Seq[Sig] = expressionSigs

/** Get the attribute index of the extension aggregate functions. */
override def getAttrsIndexForExtensionAggregateExpr(
aggregateFunc: AggregateFunction,
mode: AggregateMode,
exp: AggregateExpression,
aggregateAttributeList: Seq[Attribute],
aggregateAttr: ListBuffer[Attribute],
resIndex: Int): Int = {
var reIndex = resIndex
aggregateFunc match {
case CustomSum(_, _) =>
mode match {
// custom logic: can not support 'PartialMerge'
case Partial =>
val aggBufferAttr = aggregateFunc.inputAggBufferAttributes
if (aggBufferAttr.size == 2) {
// decimal sum check sum.resultType
aggregateAttr += ConverterUtils.getAttrFromExpr(aggBufferAttr.head)
val isEmptyAttr = ConverterUtils.getAttrFromExpr(aggBufferAttr(1))
aggregateAttr += isEmptyAttr
reIndex += 2
reIndex
} else {
val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr.head)
aggregateAttr += attr
reIndex += 1
reIndex
}
case Final =>
aggregateAttr += aggregateAttributeList(reIndex)
reIndex += 1
reIndex
case other =>
throw new UnsupportedOperationException(s"Unsupported aggregate mode: $other.")
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -408,25 +408,40 @@ abstract class HashAggregateExecBaseTransformer(
var resIndex = index
val mode = exp.mode
val aggregateFunc = exp.aggregateFunction
if (!checkAggFuncModeSupport(aggregateFunc, mode)) {
throw new UnsupportedOperationException(
s"Unsupported aggregate mode: $mode for ${aggregateFunc.prettyName}")
}
mode match {
case Partial | PartialMerge =>
val aggBufferAttr = aggregateFunc.inputAggBufferAttributes
for (index <- aggBufferAttr.indices) {
val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr(index))
aggregateAttr += attr
}
resIndex += aggBufferAttr.size
resIndex
case Final =>
aggregateAttr += aggregateAttributeList(resIndex)
resIndex += 1
resIndex
case other =>
throw new UnsupportedOperationException(s"Unsupported aggregate mode: $other.")
// First handle the custom aggregate functions
if (
ExpressionMappings.expressionExtensionTransformer.extensionExpressionsMapping.contains(
aggregateFunc.getClass)
) {
ExpressionMappings.expressionExtensionTransformer
.getAttrsIndexForExtensionAggregateExpr(
aggregateFunc,
mode,
exp,
aggregateAttributeList,
aggregateAttr,
index)
} else {
if (!checkAggFuncModeSupport(aggregateFunc, mode)) {
throw new UnsupportedOperationException(
s"Unsupported aggregate mode: $mode for ${aggregateFunc.prettyName}")
}
mode match {
case Partial | PartialMerge =>
val aggBufferAttr = aggregateFunc.inputAggBufferAttributes
for (index <- aggBufferAttr.indices) {
val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr(index))
aggregateAttr += attr
}
resIndex += aggBufferAttr.size
resIndex
case Final =>
aggregateAttr += aggregateAttributeList(resIndex)
resIndex += 1
resIndex
case other =>
throw new UnsupportedOperationException(s"Unsupported aggregate mode: $other.")
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ import io.glutenproject.expression.{ExpressionTransformer, Sig}

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, AggregateMode}

import scala.collection.mutable.ListBuffer

trait ExpressionExtensionTrait {

Expand All @@ -37,6 +40,18 @@ trait ExpressionExtensionTrait {
attributeSeq: Seq[Attribute]): ExpressionTransformer = {
throw new UnsupportedOperationException(s"${expr.getClass} or $expr is not supported.")
}

/** Get the attribute index of the extension aggregate functions. */
def getAttrsIndexForExtensionAggregateExpr(
aggregateFunc: AggregateFunction,
mode: AggregateMode,
exp: AggregateExpression,
aggregateAttributeList: Seq[Attribute],
aggregateAttr: ListBuffer[Attribute],
resIndex: Int): Int = {
throw new UnsupportedOperationException(
s"Aggregate function ${aggregateFunc.getClass} is not supported.")
}
}

case class DefaultExpressionExtensionTransformer() extends ExpressionExtensionTrait with Logging {
Expand Down

0 comments on commit b06c9ae

Please sign in to comment.