Commit 416c801: fix

zml1206 committed Aug 2, 2024 (1 parent: 0753524)
Showing 6 changed files with 50 additions and 8 deletions.
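
This commit moves the checks for whether an aggregate can use the hash-based or object-hash-based execution path behind the SparkShims abstraction: CollectRewriteRule previously called Aggregate.supportsHashAggregate and Aggregate.supportsObjectHashAggregate directly, which only resolve on newer Spark versions, and now asks SparkShimLoader.getSparkShims instead, with one implementation per supported Spark release.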
[1/6] CollectRewriteRule

@@ -19,6 +19,7 @@ package org.apache.gluten.extension
 import org.apache.gluten.GlutenConfig
 import org.apache.gluten.expression.ExpressionMappings
 import org.apache.gluten.expression.aggregate.{VeloxCollectList, VeloxCollectSet}
+import org.apache.gluten.sql.shims.SparkShimLoader
 import org.apache.gluten.utils.LogicalPlanSelector
 
 import org.apache.spark.sql.SparkSession
@@ -55,10 +56,10 @@ case class CollectRewriteRule(spark: SparkSession) extends Rule[LogicalPlan] {
       case PhysicalAggregation(_, aggregateExpr, _, _)
           if !GlutenConfig.getConf.veloxObjectHashAggCollectRewriteEnabled =>
         val aggregateExpressions = aggregateExpr.map(expr => expr.asInstanceOf[AggregateExpression])
-        val useHash = Aggregate.supportsHashAggregate(
+        val useHash = SparkShimLoader.getSparkShims.supportsHashAggregate(
           aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes))
         val useObjectHash = plan.conf.useObjectHashAggregation &&
-          Aggregate.supportsObjectHashAggregate(aggregateExpressions)
+          SparkShimLoader.getSparkShims.supportsObjectHashAggregate(aggregateExpressions)
         useHash || !useObjectHash
       case _ => true
     }
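
After the rewrite, the rule binds to the shim trait rather than to version-specific Spark APIs. The following is a minimal, self-contained sketch of that dispatch pattern, assuming nothing beyond standard Scala; the trait, object, and method arguments here are invented for illustration and are not the real Gluten SparkShimLoader API:

trait AggCheckShims {
  // Returns true when the hash-based aggregate path can be used.
  def supportsHashAggregate(bufferTypeNames: Seq[String]): Boolean
}

// Hypothetical "Spark 3.2 style" binding: hash aggregation is possible only
// when every aggregation buffer field has a fixed-width, mutable type.
object Spark32StyleShims extends AggCheckShims {
  private val mutableTypes = Set("IntegerType", "LongType", "DoubleType")
  override def supportsHashAggregate(bufferTypeNames: Seq[String]): Boolean =
    bufferTypeNames.forall(mutableTypes.contains)
}

object ShimDemo extends App {
  // A real loader would pick the implementation matching the running Spark version.
  val shims: AggCheckShims = Spark32StyleShims
  println(shims.supportsHashAggregate(Seq("LongType", "DoubleType"))) // true
  println(shims.supportsHashAggregate(Seq("StringType")))             // false
}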
[2/6] SparkShims (common trait)

@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.catalog.BucketSpec
 import org.apache.spark.sql.catalyst.csv.CSVOptions
 import org.apache.spark.sql.catalyst.expressions.{Attribute, BinaryExpression, Expression}
-import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, TypedImperativeAggregate}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning}
 import org.apache.spark.sql.catalyst.rules.Rule
@@ -267,4 +267,8 @@ trait SparkShims {
       DecimalType(math.min(integralLeastNumDigits + newScale, 38), newScale)
     }
   }
+
+  def supportsHashAggregate(aggregateBufferAttributes: Seq[Attribute]): Boolean
+
+  def supportsObjectHashAggregate(aggregateExpressions: Seq[AggregateExpression]): Boolean
 }
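
For context, in upstream Spark the first check verifies that every aggregation buffer attribute has a mutable fixed-width type that fits an UnsafeRow, and the second checks whether any aggregate function is a TypedImperativeAggregate, which the object hash operator can buffer as a Java object. The trait only declares the two methods; each version-specific shim below supplies the binding.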
[3/6] Spark32Shims

@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.catalog.BucketSpec
 import org.apache.spark.sql.catalyst.csv.CSVOptions
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BinaryExpression, Expression, InputFileBlockLength, InputFileBlockStart, InputFileName}
-import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, TypedImperativeAggregate}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.plans.physical.{Distribution, HashClusteredDistribution}
 import org.apache.spark.sql.catalyst.rules.Rule
@@ -37,6 +37,7 @@ import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec
 import org.apache.spark.sql.connector.catalog.Table
 import org.apache.spark.sql.connector.expressions.Transform
 import org.apache.spark.sql.execution.{FileSourceScanExec, PartitionedFileUtil, SparkPlan}
+import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec}
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.datasources.FileFormatWriter.Empty2Null
 import org.apache.spark.sql.execution.datasources.parquet.ParquetFilters
@@ -284,4 +285,13 @@ class Spark32Shims extends SparkShims {
     val s = decimalType.scale
     DecimalType(p, if (toScale > s) s else toScale)
   }
+
+  override def supportsHashAggregate(aggregateBufferAttributes: Seq[Attribute]): Boolean = {
+    HashAggregateExec.supportsAggregate(aggregateBufferAttributes)
+  }
+
+  override def supportsObjectHashAggregate(
+      aggregateExpressions: Seq[AggregateExpression]): Boolean = {
+    ObjectHashAggregateExec.supportsAggregate(aggregateExpressions)
+  }
 }
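
Spark 3.2 is the version that forces the shim: Aggregate.supportsHashAggregate and Aggregate.supportsObjectHashAggregate appear to have been added to the logical Aggregate companion object only in Spark 3.3, so on 3.2 the equivalent checks are reached through the physical operators HashAggregateExec and ObjectHashAggregateExec instead.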
[4/6] Spark33Shims

@@ -29,8 +29,8 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.catalog.BucketSpec
 import org.apache.spark.sql.catalyst.csv.CSVOptions
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.{BloomFilterAggregate, RegrR2, TypedImperativeAggregate}
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, BloomFilterAggregate, RegrR2, TypedImperativeAggregate}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
 import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
@@ -365,4 +365,13 @@ class Spark33Shims extends SparkShims {
       RebaseSpec(LegacyBehaviorPolicy.CORRECTED)
     )
   }
+
+  override def supportsHashAggregate(aggregateBufferAttributes: Seq[Attribute]): Boolean = {
+    Aggregate.supportsHashAggregate(aggregateBufferAttributes)
+  }
+
+  override def supportsObjectHashAggregate(
+      aggregateExpressions: Seq[AggregateExpression]): Boolean = {
+    Aggregate.supportsObjectHashAggregate(aggregateExpressions)
+  }
 }
[5/6] Spark34Shims

@@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.catalog.BucketSpec
 import org.apache.spark.sql.catalyst.csv.CSVOptions
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
 import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, KeyGroupedPartitioning, Partitioning}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, InternalRowComparableWrapper, TimestampFormatter}
@@ -493,4 +493,13 @@ class Spark34Shims extends SparkShims {
       RebaseSpec(LegacyBehaviorPolicy.CORRECTED)
     )
   }
+
+  override def supportsHashAggregate(aggregateBufferAttributes: Seq[Attribute]): Boolean = {
+    Aggregate.supportsHashAggregate(aggregateBufferAttributes)
+  }
+
+  override def supportsObjectHashAggregate(
+      aggregateExpressions: Seq[AggregateExpression]): Boolean = {
+    Aggregate.supportsObjectHashAggregate(aggregateExpressions)
+  }
 }
[6/6] Spark35Shims

@@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.catalog.BucketSpec
 import org.apache.spark.sql.catalyst.csv.CSVOptions
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
 import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, KeyGroupedPartitioning, Partitioning}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
@@ -518,4 +518,13 @@ class Spark35Shims extends SparkShims {
       RebaseSpec(LegacyBehaviorPolicy.CORRECTED)
     )
   }
+
+  override def supportsHashAggregate(aggregateBufferAttributes: Seq[Attribute]): Boolean = {
+    Aggregate.supportsHashAggregate(aggregateBufferAttributes)
+  }
+
+  override def supportsObjectHashAggregate(
+      aggregateExpressions: Seq[AggregateExpression]): Boolean = {
+    Aggregate.supportsObjectHashAggregate(aggregateExpressions)
+  }
 }
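
The Spark 3.3, 3.4, and 3.5 implementations are identical: all three forward to the Aggregate companion object, leaving Spark 3.2 as the only release that needs a different binding.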
