diff --git a/backends-clickhouse/src/main/java/org/apache/gluten/memory/alloc/CHNativeMemoryAllocators.java b/backends-clickhouse/src/main/java/org/apache/gluten/memory/alloc/CHNativeMemoryAllocators.java index 01fb7e3e2f7cf..0f30972fcd44d 100644 --- a/backends-clickhouse/src/main/java/org/apache/gluten/memory/alloc/CHNativeMemoryAllocators.java +++ b/backends-clickhouse/src/main/java/org/apache/gluten/memory/alloc/CHNativeMemoryAllocators.java @@ -19,13 +19,12 @@ import org.apache.gluten.memory.SimpleMemoryUsageRecorder; import org.apache.gluten.memory.memtarget.MemoryTargets; import org.apache.gluten.memory.memtarget.Spiller; +import org.apache.gluten.memory.memtarget.Spillers; import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.util.TaskResources; -import java.util.Arrays; import java.util.Collections; -import java.util.List; /** * Built-in toolkit for managing native memory allocations. To use the facility, one should import @@ -46,12 +45,12 @@ private CHNativeMemoryAllocators() {} private static CHNativeMemoryAllocatorManager createNativeMemoryAllocatorManager( String name, TaskMemoryManager taskMemoryManager, - List spillers, + Spiller spiller, SimpleMemoryUsageRecorder usage) { CHManagedCHReservationListener rl = new CHManagedCHReservationListener( - MemoryTargets.newConsumer(taskMemoryManager, name, spillers, Collections.emptyMap()), + MemoryTargets.newConsumer(taskMemoryManager, name, spiller, Collections.emptyMap()), usage); return new CHNativeMemoryAllocatorManagerImpl(CHNativeMemoryAllocator.createListenable(rl)); } @@ -67,7 +66,7 @@ public static CHNativeMemoryAllocator contextInstance() { createNativeMemoryAllocatorManager( "ContextInstance", TaskResources.getLocalTaskContext().taskMemoryManager(), - Collections.emptyList(), + Spillers.NOOP, TaskResources.getSharedUsage()); TaskResources.addResource(id, manager); } @@ -78,7 +77,7 @@ public static CHNativeMemoryAllocator contextInstanceForUT() { return CHNativeMemoryAllocator.getDefaultForUT(); } - public static CHNativeMemoryAllocator createSpillable(String name, Spiller... spillers) { + public static CHNativeMemoryAllocator createSpillable(String name, Spiller spiller) { if (!TaskResources.inSparkTask()) { throw new IllegalStateException("spiller must be used in a Spark task"); } @@ -87,7 +86,7 @@ public static CHNativeMemoryAllocator createSpillable(String name, Spiller... sp createNativeMemoryAllocatorManager( name, TaskResources.getLocalTaskContext().taskMemoryManager(), - Arrays.asList(spillers), + spiller, TaskResources.getSharedUsage()); TaskResources.addAnonymousResource(manager); // force add memory consumer to task memory manager, will release by inactivate diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/shuffle/CHColumnarShuffleWriter.scala b/backends-clickhouse/src/main/scala/org/apache/spark/shuffle/CHColumnarShuffleWriter.scala index 4a1adbec74180..c113f8d4dd319 100644 --- a/backends-clickhouse/src/main/scala/org/apache/spark/shuffle/CHColumnarShuffleWriter.scala +++ b/backends-clickhouse/src/main/scala/org/apache/spark/shuffle/CHColumnarShuffleWriter.scala @@ -29,7 +29,6 @@ import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.{SparkDirectoryUtil, Utils} import java.io.IOException -import java.util import java.util.{Locale, UUID} class CHColumnarShuffleWriter[K, V]( @@ -122,7 +121,10 @@ class CHColumnarShuffleWriter[K, V]( CHNativeMemoryAllocators.createSpillable( "ShuffleWriter", new Spiller() { - override def spill(self: MemoryTarget, size: Long): Long = { + override def spill(self: MemoryTarget, phase: Spiller.Phase, size: Long): Long = { + if (!Spillers.PHASE_SET_SPILL_ONLY.contains(phase)) { + return 0L; + } if (nativeSplitter == 0) { throw new IllegalStateException( "Fatal: spill() called before a shuffle writer " + @@ -134,8 +136,6 @@ class CHColumnarShuffleWriter[K, V]( logError(s"Gluten shuffle writer: Spilled $spilled / $size bytes of data") spilled } - - override def applicablePhases(): util.Set[Spiller.Phase] = Spillers.PHASE_SET_SPILL_ONLY } ) } diff --git a/backends-clickhouse/src/test/java/org/apache/spark/memory/TestTaskMemoryManagerSuite.java b/backends-clickhouse/src/test/java/org/apache/spark/memory/TestTaskMemoryManagerSuite.java index b575de403bfde..905ffacde023d 100644 --- a/backends-clickhouse/src/test/java/org/apache/spark/memory/TestTaskMemoryManagerSuite.java +++ b/backends-clickhouse/src/test/java/org/apache/spark/memory/TestTaskMemoryManagerSuite.java @@ -21,6 +21,7 @@ import org.apache.gluten.memory.alloc.CHNativeMemoryAllocator; import org.apache.gluten.memory.alloc.CHNativeMemoryAllocatorManagerImpl; import org.apache.gluten.memory.memtarget.MemoryTargets; +import org.apache.gluten.memory.memtarget.Spillers; import org.apache.spark.SparkConf; import org.apache.spark.internal.config.package$; @@ -52,7 +53,7 @@ public void initMemoryManager() { listener = new CHManagedCHReservationListener( MemoryTargets.newConsumer( - taskMemoryManager, "test", Collections.emptyList(), Collections.emptyMap()), + taskMemoryManager, "test", Spillers.NOOP, Collections.emptyMap()), new SimpleMemoryUsageRecorder()); manager = new CHNativeMemoryAllocatorManagerImpl(new CHNativeMemoryAllocator(-1L, listener)); diff --git a/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarCachedBatchSerializer.scala b/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarCachedBatchSerializer.scala index 1c45944e04760..b09e2f11f0db6 100644 --- a/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarCachedBatchSerializer.scala +++ b/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarCachedBatchSerializer.scala @@ -50,30 +50,31 @@ case class CachedColumnarBatch( // spotless:off /** * Feature: - * 1. This serializer supports column pruning 2. TODO: support push down filter 3. Super TODO: - * support store offheap object directly + * 1. This serializer supports column pruning + * 2. TODO: support push down filter + * 3. Super TODO: support store offheap object directly * * The data transformation pipeline: * * - Serializer ColumnarBatch -> CachedColumnarBatch - * -> serialize to byte[] + * -> serialize to byte[] * * - Deserializer CachedColumnarBatch -> ColumnarBatch - * -> deserialize to byte[] to create Velox ColumnarBatch + * -> deserialize to byte[] to create Velox ColumnarBatch * * - Serializer InternalRow -> CachedColumnarBatch (support RowToColumnar) - * -> Convert InternalRow to ColumnarBatch - * -> Serializer ColumnarBatch -> CachedColumnarBatch + * -> Convert InternalRow to ColumnarBatch + * -> Serializer ColumnarBatch -> CachedColumnarBatch * * - Serializer InternalRow -> DefaultCachedBatch (unsupport RowToColumnar) - * -> Convert InternalRow to DefaultCachedBatch using vanilla Spark serializer + * -> Convert InternalRow to DefaultCachedBatch using vanilla Spark serializer * * - Deserializer CachedColumnarBatch -> InternalRow (support ColumnarToRow) - * -> Deserializer CachedColumnarBatch -> ColumnarBatch - * -> Convert ColumnarBatch to InternalRow + * -> Deserializer CachedColumnarBatch -> ColumnarBatch + * -> Convert ColumnarBatch to InternalRow * * - Deserializer DefaultCachedBatch -> InternalRow (unsupport ColumnarToRow) - * -> Convert DefaultCachedBatch to InternalRow using vanilla Spark serializer + * -> Convert DefaultCachedBatch to InternalRow using vanilla Spark serializer */ // spotless:on class ColumnarCachedBatchSerializer extends CachedBatchSerializer with SQLConfHelper with Logging { diff --git a/gluten-celeborn/clickhouse/src/main/scala/org/apache/spark/shuffle/CHCelebornHashBasedColumnarShuffleWriter.scala b/gluten-celeborn/clickhouse/src/main/scala/org/apache/spark/shuffle/CHCelebornHashBasedColumnarShuffleWriter.scala index 524a3ee2e464a..bccd600686b8f 100644 --- a/gluten-celeborn/clickhouse/src/main/scala/org/apache/spark/shuffle/CHCelebornHashBasedColumnarShuffleWriter.scala +++ b/gluten-celeborn/clickhouse/src/main/scala/org/apache/spark/shuffle/CHCelebornHashBasedColumnarShuffleWriter.scala @@ -16,24 +16,19 @@ */ package org.apache.spark.shuffle +import org.apache.celeborn.client.ShuffleClient +import org.apache.celeborn.common.CelebornConf import org.apache.gluten.GlutenConfig import org.apache.gluten.backendsapi.clickhouse.CHBackendSettings import org.apache.gluten.memory.alloc.CHNativeMemoryAllocators -import org.apache.gluten.memory.memtarget.MemoryTarget -import org.apache.gluten.memory.memtarget.Spiller -import org.apache.gluten.memory.memtarget.Spillers +import org.apache.gluten.memory.memtarget.{MemoryTarget, Spiller, Spillers} import org.apache.gluten.vectorized._ - import org.apache.spark._ import org.apache.spark.scheduler.MapStatus import org.apache.spark.shuffle.celeborn.CelebornShuffleHandle import org.apache.spark.sql.vectorized.ColumnarBatch -import org.apache.celeborn.client.ShuffleClient -import org.apache.celeborn.common.CelebornConf - import java.io.IOException -import java.util import java.util.Locale class CHCelebornHashBasedColumnarShuffleWriter[K, V]( @@ -43,13 +38,13 @@ class CHCelebornHashBasedColumnarShuffleWriter[K, V]( celebornConf: CelebornConf, client: ShuffleClient, writeMetrics: ShuffleWriteMetricsReporter) - extends CelebornHashBasedColumnarShuffleWriter[K, V]( - shuffleId: Int, - handle, - context, - celebornConf, - client, - writeMetrics) { + extends CelebornHashBasedColumnarShuffleWriter[K, V]( + shuffleId: Int, + handle, + context, + celebornConf, + client, + writeMetrics) { private val customizedCompressCodec = customizedCompressionCodec.toUpperCase(Locale.ROOT) @@ -80,12 +75,14 @@ class CHCelebornHashBasedColumnarShuffleWriter[K, V]( GlutenConfig.getConf.chColumnarThrowIfMemoryExceed, GlutenConfig.getConf.chColumnarFlushBlockBufferBeforeEvict, GlutenConfig.getConf.chColumnarForceExternalSortShuffle, - GlutenConfig.getConf.chColumnarForceMemorySortShuffle - ) + GlutenConfig.getConf.chColumnarForceMemorySortShuffle) CHNativeMemoryAllocators.createSpillable( "CelebornShuffleWriter", new Spiller() { - override def spill(self: MemoryTarget, size: Long): Long = { + override def spill(self: MemoryTarget, phase: Spiller.Phase, size: Long): Long = { + if (!Spillers.PHASE_SET_SPILL_ONLY.contains(phase)) { + return 0L + } if (nativeShuffleWriter == -1L) { throw new IllegalStateException( "Fatal: spill() called before a celeborn shuffle writer " + @@ -98,10 +95,7 @@ class CHCelebornHashBasedColumnarShuffleWriter[K, V]( logInfo(s"Gluten shuffle writer: Spilled $spilled / $size bytes of data") spilled } - - override def applicablePhases(): util.Set[Spiller.Phase] = Spillers.PHASE_SET_SPILL_ONLY - } - ) + }) } while (records.hasNext) { val cb = records.next()._2.asInstanceOf[ColumnarBatch] diff --git a/gluten-celeborn/velox/src/main/scala/org/apache/spark/shuffle/VeloxCelebornHashBasedColumnarShuffleWriter.scala b/gluten-celeborn/velox/src/main/scala/org/apache/spark/shuffle/VeloxCelebornHashBasedColumnarShuffleWriter.scala index 37ea11a73d2a6..ab0e221928ff0 100644 --- a/gluten-celeborn/velox/src/main/scala/org/apache/spark/shuffle/VeloxCelebornHashBasedColumnarShuffleWriter.scala +++ b/gluten-celeborn/velox/src/main/scala/org/apache/spark/shuffle/VeloxCelebornHashBasedColumnarShuffleWriter.scala @@ -16,14 +16,13 @@ */ package org.apache.spark.shuffle +import org.apache.celeborn.client.ShuffleClient +import org.apache.celeborn.common.CelebornConf import org.apache.gluten.GlutenConfig import org.apache.gluten.columnarbatch.ColumnarBatches -import org.apache.gluten.memory.memtarget.MemoryTarget -import org.apache.gluten.memory.memtarget.Spiller -import org.apache.gluten.memory.memtarget.Spillers -import org.apache.gluten.memory.nmm.NativeMemoryManagers +import org.apache.gluten.exec.Runtimes +import org.apache.gluten.memory.memtarget.{MemoryTarget, Spiller, Spillers} import org.apache.gluten.vectorized._ - import org.apache.spark._ import org.apache.spark.memory.SparkMemoryUtil import org.apache.spark.scheduler.MapStatus @@ -31,11 +30,7 @@ import org.apache.spark.shuffle.celeborn.CelebornShuffleHandle import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.SparkResourceUtil -import org.apache.celeborn.client.ShuffleClient -import org.apache.celeborn.common.CelebornConf - import java.io.IOException -import java.util class VeloxCelebornHashBasedColumnarShuffleWriter[K, V]( shuffleId: Int, @@ -44,15 +39,17 @@ class VeloxCelebornHashBasedColumnarShuffleWriter[K, V]( celebornConf: CelebornConf, client: ShuffleClient, writeMetrics: ShuffleWriteMetricsReporter) - extends CelebornHashBasedColumnarShuffleWriter[K, V]( - shuffleId, - handle, - context, - celebornConf, - client, - writeMetrics) { + extends CelebornHashBasedColumnarShuffleWriter[K, V]( + shuffleId, + handle, + context, + celebornConf, + client, + writeMetrics) { + + private val runtime = Runtimes.contextInstance("CelebornShuffleWriter") - private val jniWrapper = ShuffleWriterJniWrapper.create() + private val jniWrapper = ShuffleWriterJniWrapper.create(runtime) private var splitResult: SplitResult = _ @@ -105,38 +102,32 @@ class VeloxCelebornHashBasedColumnarShuffleWriter[K, V]( clientPushBufferMaxSize, clientPushSortMemoryThreshold, celebornPartitionPusher, - NativeMemoryManagers - .create( - "CelebornShuffleWriter", - new Spiller() { - override def spill(self: MemoryTarget, size: Long): Long = { - if (nativeShuffleWriter == -1L) { - throw new IllegalStateException( - "Fatal: spill() called before a celeborn shuffle writer " + - "is created. This behavior should be" + - "optimized by moving memory " + - "allocations from make() to split()") - } - logInfo(s"Gluten shuffle writer: Trying to push $size bytes of data") - // fixme pass true when being called by self - val pushed = - jniWrapper.nativeEvict(nativeShuffleWriter, size, false) - logInfo(s"Gluten shuffle writer: Pushed $pushed / $size bytes of data") - pushed - } - - override def applicablePhases(): util.Set[Spiller.Phase] = - Spillers.PHASE_SET_SPILL_ONLY - } - ) - .getNativeInstanceHandle, handle, context.taskAttemptId(), GlutenShuffleUtils.getStartPartitionId(dep.nativePartitioning, context.partitionId), "celeborn", shuffleWriterType, - GlutenConfig.getConf.columnarShuffleReallocThreshold - ) + GlutenConfig.getConf.columnarShuffleReallocThreshold) + runtime.addSpiller(new Spiller() { + override def spill(self: MemoryTarget, phase: Spiller.Phase, size: Long): Long = { + if (!Spillers.PHASE_SET_SPILL_ONLY.contains(phase)) { + return 0L + } + if (nativeShuffleWriter == -1L) { + throw new IllegalStateException( + "Fatal: spill() called before a celeborn shuffle writer " + + "is created. This behavior should be" + + "optimized by moving memory " + + "allocations from make() to split()") + } + logInfo(s"Gluten shuffle writer: Trying to push $size bytes of data") + // fixme pass true when being called by self + val pushed = + jniWrapper.nativeEvict(nativeShuffleWriter, size, false) + logInfo(s"Gluten shuffle writer: Pushed $pushed / $size bytes of data") + pushed + } + }) } val startTime = System.nanoTime() jniWrapper.write(nativeShuffleWriter, cb.numRows, handle, availableOffHeapPerTask())