diff --git a/backends-clickhouse/src/test/scala/io/glutenproject/execution/extension/CustomAggExpressionTransformer.scala b/backends-clickhouse/src/test/scala/io/glutenproject/execution/extension/CustomAggExpressionTransformer.scala index 84ad585f7311..757795bf8634 100644 --- a/backends-clickhouse/src/test/scala/io/glutenproject/execution/extension/CustomAggExpressionTransformer.scala +++ b/backends-clickhouse/src/test/scala/io/glutenproject/execution/extension/CustomAggExpressionTransformer.scala @@ -19,11 +19,8 @@ package io.glutenproject.execution.extension import io.glutenproject.expression._ import io.glutenproject.extension.ExpressionExtensionTrait -import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import scala.collection.mutable.ListBuffer - case class CustomAggExpressionTransformer() extends ExpressionExtensionTrait { lazy val expressionSigs = Seq( @@ -32,41 +29,4 @@ case class CustomAggExpressionTransformer() extends ExpressionExtensionTrait { /** Generate the extension expressions list, format: Sig[XXXExpression]("XXXExpressionName") */ override def expressionSigList: Seq[Sig] = expressionSigs - - /** Get the attribute index of the extension aggregate functions. */ - override def getAttrsForExtensionAggregateExpr( - aggregateFunc: AggregateFunction, - mode: AggregateMode, - exp: AggregateExpression, - aggregateAttributeList: Seq[Attribute], - aggregateAttr: ListBuffer[Attribute], - resIndex: Int): Int = { - var reIndex = resIndex - aggregateFunc match { - case CustomSum(_, _) => - mode match { - case Partial | PartialMerge => - val aggBufferAttr = aggregateFunc.inputAggBufferAttributes - if (aggBufferAttr.size == 2) { - // decimal sum check sum.resultType - aggregateAttr += ConverterUtils.getAttrFromExpr(aggBufferAttr.head) - val isEmptyAttr = ConverterUtils.getAttrFromExpr(aggBufferAttr(1)) - aggregateAttr += isEmptyAttr - reIndex += 2 - reIndex - } else { - val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr.head) - aggregateAttr += attr - reIndex += 1 - reIndex - } - case Final => - aggregateAttr += aggregateAttributeList(reIndex) - reIndex += 1 - reIndex - case other => - throw new UnsupportedOperationException(s"Unsupported aggregate mode: $other.") - } - } - } } diff --git a/backends-velox/src/main/scala/io/glutenproject/execution/HashAggregateExecTransformer.scala b/backends-velox/src/main/scala/io/glutenproject/execution/HashAggregateExecTransformer.scala index 49f80ab119bb..68b55501424e 100644 --- a/backends-velox/src/main/scala/io/glutenproject/execution/HashAggregateExecTransformer.scala +++ b/backends-velox/src/main/scala/io/glutenproject/execution/HashAggregateExecTransformer.scala @@ -37,7 +37,6 @@ import com.google.protobuf.Any import java.util import scala.collection.JavaConverters._ -import scala.collection.mutable.ListBuffer case class HashAggregateExecTransformer( requiredChildDistributionExpressions: Option[Seq[Expression]], @@ -56,34 +55,18 @@ case class HashAggregateExecTransformer( resultExpressions, child) { - override protected def getAttrForAggregateExpr( - exp: AggregateExpression, - aggregateAttributeList: Seq[Attribute], - aggregateAttr: ListBuffer[Attribute], - index: Int): Int = { - var resIndex = index - val mode = exp.mode - val aggregateFunc = exp.aggregateFunction - aggregateFunc match { - case hllAdapter: HLLAdapter => + override protected def checkAggFuncModeSupport( + aggFunc: AggregateFunction, + mode: AggregateMode): Boolean = { + aggFunc match { + case _: HLLAdapter => mode match { - case Partial => - val aggBufferAttr = hllAdapter.inputAggBufferAttributes - for (index <- aggBufferAttr.indices) { - val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr(index)) - aggregateAttr += attr - } - resIndex += aggBufferAttr.size - case Final => - aggregateAttr += aggregateAttributeList(resIndex) - resIndex += 1 - case other => - throw new UnsupportedOperationException(s"not currently supported: $other.") + case Partial | Final => true + case _ => false } case _ => - resIndex = super.getAttrForAggregateExpr(exp, aggregateAttributeList, aggregateAttr, index) + super.checkAggFuncModeSupport(aggFunc, mode) } - resIndex } override protected def withNewChildInternal(newChild: SparkPlan): HashAggregateExecTransformer = { diff --git a/gluten-core/src/main/scala/io/glutenproject/execution/HashAggregateExecBaseTransformer.scala b/gluten-core/src/main/scala/io/glutenproject/execution/HashAggregateExecBaseTransformer.scala index 91844652938f..1a5fb97dd945 100644 --- a/gluten-core/src/main/scala/io/glutenproject/execution/HashAggregateExecBaseTransformer.scala +++ b/gluten-core/src/main/scala/io/glutenproject/execution/HashAggregateExecBaseTransformer.scala @@ -34,7 +34,6 @@ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.aggregate._ import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.ColumnarBatch -import org.apache.spark.util.sketch.BloomFilter import com.google.protobuf.Any @@ -417,195 +416,47 @@ abstract class HashAggregateExecBaseTransformer( var resIndex = index val mode = exp.mode val aggregateFunc = exp.aggregateFunction - aggregateFunc match { - case extendedAggFunc - if ExpressionMappings.expressionExtensionTransformer.extensionExpressionsMapping.contains( - extendedAggFunc.getClass) => - // get attributes from the extended aggregate functions - ExpressionMappings.expressionExtensionTransformer - .getAttrsForExtensionAggregateExpr( - aggregateFunc, - mode, - exp, - aggregateAttributeList, - aggregateAttr, - index) - case _: Average | _: First | _: Last => - mode match { - case Partial | PartialMerge => - val aggBufferAttr = aggregateFunc.inputAggBufferAttributes - for (index <- aggBufferAttr.indices) { - val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr(index)) - aggregateAttr += attr - } - resIndex += 2 - resIndex - case Final => - aggregateAttr += aggregateAttributeList(resIndex) - resIndex += 1 - resIndex - case other => - throw new UnsupportedOperationException(s"Unsupported aggregate mode: $other.") - } - case Sum(_, _) => - mode match { - case Partial | PartialMerge => - val sum = aggregateFunc.asInstanceOf[Sum] - val aggBufferAttr = sum.inputAggBufferAttributes - if (aggBufferAttr.size == 2) { - // decimal sum check sum.resultType - aggregateAttr += ConverterUtils.getAttrFromExpr(aggBufferAttr.head) - val isEmptyAttr = ConverterUtils.getAttrFromExpr(aggBufferAttr(1)) - aggregateAttr += isEmptyAttr - resIndex += 2 - resIndex - } else { - val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr.head) - aggregateAttr += attr - resIndex += 1 - resIndex - } - case Final => - aggregateAttr += aggregateAttributeList(resIndex) - resIndex += 1 - resIndex - case other => - throw new UnsupportedOperationException(s"Unsupported aggregate mode: $other.") - } - case Count(_) => - mode match { - case Partial | PartialMerge => - val count = aggregateFunc.asInstanceOf[Count] - val aggBufferAttr = count.inputAggBufferAttributes - val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr.head) - aggregateAttr += attr - resIndex += 1 - resIndex - case Final => - aggregateAttr += aggregateAttributeList(resIndex) - resIndex += 1 - resIndex - case other => - throw new UnsupportedOperationException(s"Unsupported aggregate mode: $other.") - } - case _: Max | _: Min | _: BitAndAgg | _: BitOrAgg | _: BitXorAgg => - mode match { - case Partial | PartialMerge => - val aggBufferAttr = aggregateFunc.inputAggBufferAttributes - assert( - aggBufferAttr.size == 1, - s"Aggregate function $aggregateFunc expects one buffer attribute.") - val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr.head) - aggregateAttr += attr - resIndex += 1 - resIndex - case Final => - aggregateAttr += aggregateAttributeList(resIndex) - resIndex += 1 - resIndex - case other => - throw new UnsupportedOperationException(s"Unsupported aggregate mode: $other.") - } - case _: Corr => - mode match { - case Partial | PartialMerge => - val expectedBufferSize = 6 - val aggBufferAttr = aggregateFunc.inputAggBufferAttributes - assert( - aggBufferAttr.size == expectedBufferSize, - s"Aggregate function $aggregateFunc" + - s" expects $expectedBufferSize buffer attribute.") - for (index <- aggBufferAttr.indices) { - val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr(index)) - aggregateAttr += attr - } - resIndex += expectedBufferSize - resIndex - case Final => - aggregateAttr += aggregateAttributeList(resIndex) - resIndex += 1 - resIndex - case other => - throw new UnsupportedOperationException(s"Unsupported aggregate mode: $other.") - } - case _: CovPopulation | _: CovSample => - mode match { - case Partial | PartialMerge => - val expectedBufferSize = 4 - val aggBufferAttr = aggregateFunc.inputAggBufferAttributes - assert( - aggBufferAttr.size == expectedBufferSize, - s"Aggregate function $aggregateFunc" + - s" expects $expectedBufferSize buffer attributes.") - for (index <- aggBufferAttr.indices) { - val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr(index)) - aggregateAttr += attr - } - resIndex += expectedBufferSize - resIndex - case Final => - aggregateAttr += aggregateAttributeList(resIndex) - resIndex += 1 - resIndex - case other => - throw new UnsupportedOperationException(s"Unsupported aggregate mode: $other.") + if (!checkAggFuncModeSupport(aggregateFunc, mode)) { + throw new UnsupportedOperationException( + s"Unsupported aggregate mode: $mode for ${aggregateFunc.prettyName}") + } + mode match { + case Partial | PartialMerge => + val aggBufferAttr = aggregateFunc.inputAggBufferAttributes + for (index <- aggBufferAttr.indices) { + val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr(index)) + aggregateAttr += attr } - case _: StddevSamp | _: StddevPop | _: VarianceSamp | _: VariancePop => + resIndex += aggBufferAttr.size + resIndex + case Final => + aggregateAttr += aggregateAttributeList(resIndex) + resIndex += 1 + resIndex + case other => + throw new UnsupportedOperationException(s"Unsupported aggregate mode: $other.") + } + } + + protected def checkAggFuncModeSupport( + aggFunc: AggregateFunction, + mode: AggregateMode): Boolean = { + aggFunc match { + case _: CollectList | _: CollectSet => mode match { - case Partial | PartialMerge => - val aggBufferAttr = aggregateFunc.inputAggBufferAttributes - for (index <- aggBufferAttr.indices) { - val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr(index)) - aggregateAttr += attr - } - resIndex += 3 - resIndex - case Final => - aggregateAttr += aggregateAttributeList(resIndex) - resIndex += 1 - resIndex - case other => - throw new UnsupportedOperationException(s"Unsupported aggregate mode: $other.") + case Partial | Final => true + case _ => false } case bloom if bloom.getClass.getSimpleName.equals("BloomFilterAggregate") => - // for spark33 mode match { - case Partial => - val bloom = aggregateFunc.asInstanceOf[TypedImperativeAggregate[BloomFilter]] - val aggBufferAttr = bloom.inputAggBufferAttributes - for (index <- aggBufferAttr.indices) { - val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr(index)) - aggregateAttr += attr - } - resIndex += aggBufferAttr.size - resIndex - case Final => - aggregateAttr += aggregateAttributeList(resIndex) - resIndex += 1 - resIndex - case other => - throw new UnsupportedOperationException(s"Unsupported aggregate mode: $other.") + case Partial | Final => true + case _ => false } - case _: CollectList | _: CollectSet => + case _ => mode match { - case Partial => - val aggBufferAttr = aggregateFunc.inputAggBufferAttributes - for (index <- aggBufferAttr.indices) { - val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr(index)) - aggregateAttr += attr - } - resIndex += aggBufferAttr.size - resIndex - case Final => - aggregateAttr += aggregateAttributeList(resIndex) - resIndex += 1 - resIndex - case other => - throw new UnsupportedOperationException(s"Unsupported aggregate mode: $other.") + case Partial | PartialMerge | Final => true + case _ => false } - case other => - throw new UnsupportedOperationException( - s"Unsupported aggregate function in getAttrForAggregateExpr") } } diff --git a/gluten-core/src/main/scala/io/glutenproject/extension/ExpressionExtensionTrait.scala b/gluten-core/src/main/scala/io/glutenproject/extension/ExpressionExtensionTrait.scala index f3ec1b166477..5a5fb08c3c91 100644 --- a/gluten-core/src/main/scala/io/glutenproject/extension/ExpressionExtensionTrait.scala +++ b/gluten-core/src/main/scala/io/glutenproject/extension/ExpressionExtensionTrait.scala @@ -20,9 +20,6 @@ import io.glutenproject.expression.{ExpressionTransformer, Sig} import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, AggregateMode} - -import scala.collection.mutable.ListBuffer trait ExpressionExtensionTrait { @@ -40,18 +37,6 @@ trait ExpressionExtensionTrait { attributeSeq: Seq[Attribute]): ExpressionTransformer = { throw new UnsupportedOperationException(s"${expr.getClass} or $expr is not supported.") } - - /** Get the attribute index of the extension aggregate functions. */ - def getAttrsForExtensionAggregateExpr( - aggregateFunc: AggregateFunction, - mode: AggregateMode, - exp: AggregateExpression, - aggregateAttributeList: Seq[Attribute], - aggregateAttr: ListBuffer[Attribute], - resIndex: Int): Int = { - throw new UnsupportedOperationException( - s"Aggregate function ${aggregateFunc.getClass} is not supported.") - } } case class DefaultExpressionExtensionTransformer() extends ExpressionExtensionTrait with Logging {