[GLUTEN-3486][CH] Fix AQE cannot coalesce shuffle partitions (#3941)
Fix an issue where AQE cannot coalesce shuffle partitions when the ClickHouse backend is used.
exmy authored Dec 8, 2023
1 parent 1b87ae9 commit 2ccea31
Showing 3 changed files with 46 additions and 5 deletions.
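
Background on the change, with a simplified sketch below (the object name PartitioningSketch, the attribute names and the partition count are invented for illustration; only the general mechanism is taken from the diff): when the shuffle hash keys contain a non-attribute expression, the CH backend materializes them through a projection. Before this commit the HashPartitioning was then rebuilt over the projected attributes, so the partitioning carried by the shuffle no longer matched the original one, which appears to be why AQE could no longer coalesce the shuffle partitions. The patch keeps the original expressions visible to Catalyst and only hands the rewritten ones to the native shuffle writer. The sketch merely shows that a rebuilt partitioning is not semantically equal to the original:

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Murmur3Hash}
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.types.IntegerType

object PartitioningSketch {
  def main(args: Array[String]): Unit = {
    // Original shuffle key: a non-attribute expression such as hash(a).
    val a = AttributeReference("a", IntegerType)()
    val original = HashPartitioning(Seq(new Murmur3Hash(Seq(a))), 200)

    // Pre-patch behaviour (simplified): the key is materialized by a projection
    // and the partitioning is rebuilt over the projected attribute.
    val projected = AttributeReference("_pre_0", IntegerType)()
    val rewritten = HashPartitioning(Seq(projected), 200)

    // The rebuilt partitioning is no longer semantically equal to the original.
    println(original.semanticEquals(rewritten)) // prints: false
  }
}
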
@@ -27,7 +27,7 @@ import org.apache.spark.ShuffleDependency
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.serializer.Serializer
-import org.apache.spark.shuffle.{ColumnarShuffleDependency, GlutenShuffleUtils}
+import org.apache.spark.shuffle.{ColumnarShuffleDependency, GlutenShuffleUtils, HashPartitioningWrapper}
import org.apache.spark.shuffle.utils.RangePartitionerBoundsGenerator
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, BindReferences, BoundReference, UnsafeProjection, UnsafeRow}
@@ -175,11 +175,17 @@ object CHExecUtil extends Logging {
}

private def buildHashPartitioning(
-partitoining: HashPartitioning,
+partitioning: HashPartitioning,
childOutput: Seq[Attribute],
output: Seq[Attribute]): NativePartitioning = {
+val hashExpressions = partitioning match {
+  case partitioning: HashPartitioningWrapper =>
+    partitioning.getNewExpr
+  case _ =>
+    partitioning.expressions
+}
val hashFields =
-partitoining.expressions.map(
+hashExpressions.map(
a =>
BindReferences
.bindReference(ConverterUtils.getAttrFromExpr(a).toAttribute, childOutput)
@@ -199,7 +205,7 @@

new NativePartitioning(
GlutenShuffleUtils.HashPartitioningShortName,
-partitoining.numPartitions,
+partitioning.numPartitions,
Array.empty[Byte],
hashFields.mkString(",").getBytes(),
outputFields.mkString(",").getBytes()
@@ -26,6 +26,7 @@ import io.glutenproject.utils.{LogLevelUtil, PhysicalPlanSelector}

import org.apache.spark.api.python.EvalPythonExecTransformer
import org.apache.spark.internal.Logging
+import org.apache.spark.shuffle.HashPartitioningWrapper
import org.apache.spark.sql.{SparkSession, SparkSessionExtensions}
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, BindReferences, BoundReference, Expression, Murmur3Hash, NamedExpression, SortOrder}
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
@@ -204,7 +205,10 @@ case class TransformPreOverrides(isAdaptiveContext: Boolean)
val pos = newExpressionsPosition(i)
newExprs = newExprs :+ project.output(pos)
}
-(projectExpressions.size, HashPartitioning(newExprs, numPartitions), project)
+(
+  projectExpressions.size,
+  new HashPartitioningWrapper(exprs, newExprs, numPartitions),
+  project)
case RangePartitioning(orderings, numPartitions) =>
val exprs = orderings.map(ordering => ordering.child)
val (projectExpressions, newExpressionsPosition) = {
@@ -0,0 +1,31 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.shuffle

import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning

// A wrapper around HashPartitioning that retains the original hash expressions.
// Only used by the CH backend when the shuffle hash expressions contain non-field expressions.
class HashPartitioningWrapper(
original: Seq[Expression],
newExpr: Seq[Expression],
override val numPartitions: Int)
extends HashPartitioning(original, numPartitions) {

def getNewExpr: Seq[Expression] = newExpr
}
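
As a usage note, here is a hypothetical, self-contained sketch (the object name WrapperSketch, the attribute names, the hash expression and the partition count are invented; only the class and its two call sites come from this commit): TransformPreOverrides builds the wrapper with both the original and the rewritten expressions, generic consumers still see a plain HashPartitioning over the original expressions, and CHExecUtil.buildHashPartitioning pattern-matches on the wrapper to pick up the rewritten expressions for the native shuffle:

import org.apache.spark.shuffle.HashPartitioningWrapper
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, Murmur3Hash}
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.types.IntegerType

object WrapperSketch {
  def main(args: Array[String]): Unit = {
    // Hypothetical shuffle key hash(a) and the attribute a projection would produce for it.
    val a = AttributeReference("a", IntegerType)()
    val projected = AttributeReference("_pre_0", IntegerType)()

    // Constructed the way TransformPreOverrides does it: original exprs plus rewritten exprs.
    val partitioning: HashPartitioning =
      new HashPartitioningWrapper(Seq(new Murmur3Hash(Seq(a))), Seq(projected), 10)

    // Catalyst-facing view: still a HashPartitioning over the original expressions.
    println(partitioning.expressions)

    // CH-backend view, mirroring CHExecUtil.buildHashPartitioning: use the rewritten expressions.
    val hashExpressions: Seq[Expression] = partitioning match {
      case w: HashPartitioningWrapper => w.getNewExpr
      case p => p.expressions
    }
    println(hashExpressions)
  }
}
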
