diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/shuffle/utils/RangePartitionerBoundsGenerator.scala b/backends-clickhouse/src/main/scala/org/apache/spark/shuffle/utils/RangePartitionerBoundsGenerator.scala
index 706cc5f34108..694035b878a5 100644
--- a/backends-clickhouse/src/main/scala/org/apache/spark/shuffle/utils/RangePartitionerBoundsGenerator.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/spark/shuffle/utils/RangePartitionerBoundsGenerator.scala
@@ -210,28 +210,31 @@ class RangePartitionerBoundsGenerator[K: Ordering: ClassTag, V](
     arrayNode
   }
 
-  private def buildRangeBoundsJson(jsonMapper: ObjectMapper, arrayNode: ArrayNode): Unit = {
+  private def buildRangeBoundsJson(jsonMapper: ObjectMapper, arrayNode: ArrayNode): Int = {
     val bounds = getRangeBounds
     bounds.foreach {
       bound =>
         val row = bound.asInstanceOf[UnsafeRow]
         arrayNode.add(buildRangeBoundJson(row, ordering, jsonMapper))
     }
+    bounds.length
   }
 
   // Make a json structure that can be passed to native engine
-  def getRangeBoundsJsonString: String = {
+  def getRangeBoundsJsonString: RangeBoundsInfo = {
     val context = new SubstraitContext()
     val mapper = new ObjectMapper
     val rootNode = mapper.createObjectNode
     val orderingArray = rootNode.putArray("ordering")
     buildOrderingJson(context, ordering, inputAttributes, mapper, orderingArray)
     val boundArray = rootNode.putArray("range_bounds")
-    buildRangeBoundsJson(mapper, boundArray)
-    mapper.writeValueAsString(rootNode)
+    val boundLength = buildRangeBoundsJson(mapper, boundArray)
+    RangeBoundsInfo(mapper.writeValueAsString(rootNode), boundLength)
   }
 }
 
+case class RangeBoundsInfo(json: String, boundsSize: Int)
+
 object RangePartitionerBoundsGenerator {
   def supportedFieldType(dataType: DataType): Boolean = {
     dataType match {
diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/utils/CHExecUtil.scala b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/utils/CHExecUtil.scala
index cc172ac4b543..83bde9d168fd 100644
--- a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/utils/CHExecUtil.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/utils/CHExecUtil.scala
@@ -311,7 +311,7 @@ object CHExecUtil extends Logging {
           rddForSampling,
           sortingExpressions,
           childOutputAttributes)
-        val orderingAndRangeBounds = generator.getRangeBoundsJsonString
+        val rangeBoundsInfo = generator.getRangeBoundsJsonString
         val attributePos = if (projectOutputAttributes != null) {
           projectOutputAttributes.map(
             attr =>
@@ -324,10 +324,11 @@ object CHExecUtil extends Logging {
         }
         new NativePartitioning(
           GlutenShuffleUtils.RangePartitioningShortName,
-          numPartitions,
+          rangeBoundsInfo.boundsSize + 1,
           Array.empty[Byte],
-          orderingAndRangeBounds.getBytes(),
-          attributePos.mkString(",").getBytes)
+          rangeBoundsInfo.json.getBytes,
+          attributePos.mkString(",").getBytes
+        )
       case p =>
         throw new IllegalStateException(s"Unknow partition type: ${p.getClass.toString}")
     }
@@ -368,7 +369,7 @@ object CHExecUtil extends Logging {
     val dependency =
       new ColumnarShuffleDependency[Int, ColumnarBatch, ColumnarBatch](
         rddWithPartitionKey,
-        new PartitionIdPassthrough(newPartitioning.numPartitions),
+        new PartitionIdPassthrough(nativePartitioning.getNumPartitions),
         serializer,
         shuffleWriterProcessor = ShuffleExchangeExec.createShuffleWriteProcessor(writeMetrics),
         nativePartitioning = nativePartitioning,
diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseColumnarShuffleAQESuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseColumnarShuffleAQESuite.scala
index b43fc2625f0b..f837484b0ca8 100644
--- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseColumnarShuffleAQESuite.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseColumnarShuffleAQESuite.scala
@@ -57,7 +57,7 @@ class GlutenClickHouseColumnarShuffleAQESuite
         val coalescedPartitionSpec0 = colCustomShuffleReaderExecs.head.partitionSpecs.head
           .asInstanceOf[CoalescedPartitionSpec]
         assert(coalescedPartitionSpec0.startReducerIndex == 0)
-        assert(coalescedPartitionSpec0.endReducerIndex == 5)
+        assert(coalescedPartitionSpec0.endReducerIndex == 4)
         val coalescedPartitionSpec1 = colCustomShuffleReaderExecs(1).partitionSpecs.head
           .asInstanceOf[CoalescedPartitionSpec]
         assert(coalescedPartitionSpec1.startReducerIndex == 0)
diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseJoinSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseJoinSuite.scala
index 1a276a26b258..7a9abd3bad6f 100644
--- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseJoinSuite.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseJoinSuite.scala
@@ -21,8 +21,11 @@ import org.apache.gluten.backendsapi.clickhouse.CHConf
 import org.apache.gluten.utils.UTSystemParameters
 
 import org.apache.spark.SparkConf
+import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
 import org.apache.spark.sql.execution.datasources.v2.clickhouse.ClickHouseConfig
 
+import java.util.concurrent.atomic.AtomicInteger
+
 class GlutenClickHouseJoinSuite extends GlutenClickHouseWholeStageTransformerSuite {
 
   protected val tablesPath: String = basePath + "/tpch-data"
@@ -141,4 +144,38 @@ class GlutenClickHouseJoinSuite extends GlutenClickHouseWholeStageTransformerSui
     sql("drop table if exists tj2")
   }
 
+  test("GLUTEN-8216 Fix OOM when cartesian product with empty data") {
+    // prepare
+    spark.sql("create table test_join(a int, b int, c int) using parquet")
+    var overrideConfs = Map(
+      "spark.sql.autoBroadcastJoinThreshold" -> "-1",
+      "spark.sql.shuffle.partitions" -> "1"
+    )
+    if (isSparkVersionGE("3.5")) {
+      // Range partitions will not be reduced if EliminateSorts is enabled in spark35.
+      overrideConfs += "spark.sql.optimizer.excludedRules" ->
+        "org.apache.spark.sql.catalyst.optimizer.EliminateSorts"
+    }
+
+    withSQLConf(overrideConfs.toSeq: _*) {
+      val taskCount = new AtomicInteger(0)
+      val taskListener = new SparkListener {
+        override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
+          taskCount.incrementAndGet()
+          logDebug(s"Task ${taskEnd.taskInfo.id} finished. Total tasks completed: $taskCount")
+        }
+      }
+      spark.sparkContext.addSparkListener(taskListener)
+      spark
+        .sql(
+          "select * from " +
+            "(select a from test_join group by a order by a), " +
+            "(select b from test_join group by b order by b)" +
+            " limit 10000"
+        )
+        .collect()
+      assert(taskCount.get() < 500)
+    }
+  }
+
 }
diff --git a/gluten-celeborn/clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseRSSColumnarShuffleAQESuite.scala b/gluten-celeborn/clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseRSSColumnarShuffleAQESuite.scala
index 00f3bee8eb7b..e62dbdd2a5fe 100644
--- a/gluten-celeborn/clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseRSSColumnarShuffleAQESuite.scala
+++ b/gluten-celeborn/clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseRSSColumnarShuffleAQESuite.scala
@@ -62,7 +62,7 @@ class GlutenClickHouseRSSColumnarShuffleAQESuite
           .partitionSpecs(0)
           .asInstanceOf[CoalescedPartitionSpec]
         assert(coalescedPartitionSpec0.startReducerIndex == 0)
-        assert(coalescedPartitionSpec0.endReducerIndex == 5)
+        assert(coalescedPartitionSpec0.endReducerIndex == 4)
         val coalescedPartitionSpec1 = colCustomShuffleReaderExecs(1)
           .partitionSpecs(0)
           .asInstanceOf[CoalescedPartitionSpec]