diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 50f7143145a94..f97725655995b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -712,6 +712,14 @@ object SQLConf { .booleanConf .createWithDefault(true) + val PARTIAL_STAGE_NOT_COALESCE_PARTITIONS_ENABLED = + buildConf("spark.sql.adaptive.partial.stage.not.coalescePartitions.enabled") + .doc("When true, skip coalescing shuffle partitions for a query stage that contains an Expand operator, because coalescing such a stage may cause it to run too slowly.") + .version("4.0.0") + .booleanConf + .createWithDefault(false) + + val COALESCE_PARTITIONS_PARALLELISM_FIRST = buildConf("spark.sql.adaptive.coalescePartitions.parallelismFirst") .doc("When true, Spark does not respect the target size specified by " + @@ -5533,6 +5541,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf { def coalesceShufflePartitionsEnabled: Boolean = getConf(COALESCE_PARTITIONS_ENABLED) + def partialStageNotcoalesceShufflePartitionsEnabled: Boolean = + getConf(PARTIAL_STAGE_NOT_COALESCE_PARTITIONS_ENABLED) + def minBatchesToRetain: Int = getConf(MIN_BATCHES_TO_RETAIN) def ratioExtraSpaceAllowedInCheckpoint: Double = getConf(RATIO_EXTRA_SPACE_ALLOWED_IN_CHECKPOINT) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala index a0a0991429309..fb158d94d887e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala @@ -545,6 +545,37 @@ case class AdaptiveSparkPlanExec( this.inputPlan == obj.asInstanceOf[AdaptiveSparkPlanExec].inputPlan }
+ private def containsExpandExec(plan: SparkPlan): Boolean = { + def traverse(plan: SparkPlan): Boolean = { + plan match { + case _: ShuffleQueryStageExec => + false + case _: ExpandExec => + true + case _ => + plan.children.exists(traverse) + } + } + + traverse(plan) + } + + private def findAndApplyToAllShuffleExchanges(plan: SparkPlan): SparkPlan = { + def traverseAndModify(plan: SparkPlan): SparkPlan = { + plan match { + case shuffleExec: ShuffleQueryStageExec => + shuffleExec.setUseCoalesceShufflePartitions(false) + shuffleExec + case _ => + val modifiedChildren = plan.children.map(traverseAndModify) + plan.withNewChildren(modifiedChildren) + } + } + + traverseAndModify(plan) + } + + /** * This method is called recursively to traverse the plan tree bottom-up and create a new query * stage or try reusing an existing stage if the current node is an [[Exchange]] node and all of @@ -569,7 +600,10 @@ case class AdaptiveSparkPlanExec( case _ => val result = createQueryStages(e.child) - val newPlan = e.withNewChildren(Seq(result.newPlan)).asInstanceOf[Exchange] + var newPlan = e.withNewChildren(Seq(result.newPlan)).asInstanceOf[Exchange] + if (conf.partialStageNotcoalesceShufflePartitionsEnabled && containsExpandExec(newPlan)) { + newPlan = findAndApplyToAllShuffleExchanges(newPlan).asInstanceOf[Exchange] + } // Create a query stage only when all the child query stages are ready. 
if (result.allChildStagesMaterialized) { var newStage = newQueryStage(newPlan).asInstanceOf[ExchangeQueryStageExec] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala index 1bbc26f3e52ed..333dcb1f75138 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala @@ -45,6 +45,12 @@ case class CoalesceShufflePartitions(session: SparkSession) extends AQEShuffleRe if (!conf.coalesceShufflePartitionsEnabled) { return plan } + if (conf.partialStageNotcoalesceShufflePartitionsEnabled) { + val shuffleStageInfos = collectShuffleStageInfos(plan) + if (!shuffleStageInfos.forall(s => s.shuffleStage.getUseCoalesceShufflePartitions)) { + return plan + } + } // Ideally, this rule should simply coalesce partitions w.r.t. the target size specified by // ADVISORY_PARTITION_SIZE_IN_BYTES (default 64MB). 
To avoid perf regression in AQE, this diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala index 2391fe740118d..342fff7fbaaa2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala @@ -203,6 +203,14 @@ case class ShuffleQueryStageExec( override protected def doMaterialize(): Future[Any] = shuffle.submitShuffleJob() + private var useCoalesceShufflePartitions: Boolean = true + + def getUseCoalesceShufflePartitions: Boolean = useCoalesceShufflePartitions + + def setUseCoalesceShufflePartitions(useCoalesceShufflePartitions: Boolean): Unit = { + this.useCoalesceShufflePartitions = useCoalesceShufflePartitions + } + override def newReuseInstance( newStageId: Int, newOutput: Seq[Attribute]): ExchangeQueryStageExec = { val reuse = ShuffleQueryStageExec( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/CoalesceShufflePartitionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/CoalesceShufflePartitionsSuite.scala index 9ed4f1a006b2b..97a1a54c06861 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/CoalesceShufflePartitionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/CoalesceShufflePartitionsSuite.scala @@ -373,6 +373,19 @@ class CoalesceShufflePartitionsSuite extends SparkFunSuite with SQLConfHelper { withSparkSession(test, 100, None) } +// test("SPARK-50257 If a stage contains ExpandExec, the CoalesceShufflePartitions rule " + +// "will not be adjusted during the AQE phase") { +// val test: SparkSession => Unit = { spark: SparkSession => +// withSQLConf( +// ("spark.sql.adaptive.partial.stage.not.coalescePartitions.enabled" -> "true"), +// ("spark.sql.shuffle.partitions" -> "1000") +// ) { +// +// +// } +// } +// } +
test("SPARK-24705 adaptive query execution works correctly when exchange reuse enabled") { val test: SparkSession => Unit = { spark: SparkSession => withSQLConf("spark.sql.exchange.reuse" -> "true") {