Skip to content

Commit

Permalink
fixup
Browse files Browse the repository at this point in the history
  • Loading branch information
zhztheplayer committed Jun 27, 2024
1 parent f43313b commit 609a976
Show file tree
Hide file tree
Showing 6 changed files with 74 additions and 88 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,12 @@
import org.apache.gluten.memory.SimpleMemoryUsageRecorder;
import org.apache.gluten.memory.memtarget.MemoryTargets;
import org.apache.gluten.memory.memtarget.Spiller;
import org.apache.gluten.memory.memtarget.Spillers;

import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.util.TaskResources;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;

/**
* Built-in toolkit for managing native memory allocations. To use the facility, one should import
Expand All @@ -46,12 +45,12 @@ private CHNativeMemoryAllocators() {}
private static CHNativeMemoryAllocatorManager createNativeMemoryAllocatorManager(
String name,
TaskMemoryManager taskMemoryManager,
List<Spiller> spillers,
Spiller spiller,
SimpleMemoryUsageRecorder usage) {

CHManagedCHReservationListener rl =
new CHManagedCHReservationListener(
MemoryTargets.newConsumer(taskMemoryManager, name, spillers, Collections.emptyMap()),
MemoryTargets.newConsumer(taskMemoryManager, name, spiller, Collections.emptyMap()),
usage);
return new CHNativeMemoryAllocatorManagerImpl(CHNativeMemoryAllocator.createListenable(rl));
}
Expand All @@ -67,7 +66,7 @@ public static CHNativeMemoryAllocator contextInstance() {
createNativeMemoryAllocatorManager(
"ContextInstance",
TaskResources.getLocalTaskContext().taskMemoryManager(),
Collections.emptyList(),
Spillers.NOOP,
TaskResources.getSharedUsage());
TaskResources.addResource(id, manager);
}
Expand All @@ -78,7 +77,7 @@ public static CHNativeMemoryAllocator contextInstanceForUT() {
return CHNativeMemoryAllocator.getDefaultForUT();
}

public static CHNativeMemoryAllocator createSpillable(String name, Spiller... spillers) {
public static CHNativeMemoryAllocator createSpillable(String name, Spiller spiller) {
if (!TaskResources.inSparkTask()) {
throw new IllegalStateException("spiller must be used in a Spark task");
}
Expand All @@ -87,7 +86,7 @@ public static CHNativeMemoryAllocator createSpillable(String name, Spiller... sp
createNativeMemoryAllocatorManager(
name,
TaskResources.getLocalTaskContext().taskMemoryManager(),
Arrays.asList(spillers),
spiller,
TaskResources.getSharedUsage());
TaskResources.addAnonymousResource(manager);
// force add memory consumer to task memory manager, will release by inactivate
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.{SparkDirectoryUtil, Utils}

import java.io.IOException
import java.util
import java.util.{Locale, UUID}

class CHColumnarShuffleWriter[K, V](
Expand Down Expand Up @@ -122,7 +121,10 @@ class CHColumnarShuffleWriter[K, V](
CHNativeMemoryAllocators.createSpillable(
"ShuffleWriter",
new Spiller() {
override def spill(self: MemoryTarget, size: Long): Long = {
override def spill(self: MemoryTarget, phase: Spiller.Phase, size: Long): Long = {
if (!Spillers.PHASE_SET_SPILL_ONLY.contains(phase)) {
return 0L;
}
if (nativeSplitter == 0) {
throw new IllegalStateException(
"Fatal: spill() called before a shuffle writer " +
Expand All @@ -134,8 +136,6 @@ class CHColumnarShuffleWriter[K, V](
logError(s"Gluten shuffle writer: Spilled $spilled / $size bytes of data")
spilled
}

override def applicablePhases(): util.Set[Spiller.Phase] = Spillers.PHASE_SET_SPILL_ONLY
}
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.apache.gluten.memory.alloc.CHNativeMemoryAllocator;
import org.apache.gluten.memory.alloc.CHNativeMemoryAllocatorManagerImpl;
import org.apache.gluten.memory.memtarget.MemoryTargets;
import org.apache.gluten.memory.memtarget.Spillers;

import org.apache.spark.SparkConf;
import org.apache.spark.internal.config.package$;
Expand Down Expand Up @@ -52,7 +53,7 @@ public void initMemoryManager() {
listener =
new CHManagedCHReservationListener(
MemoryTargets.newConsumer(
taskMemoryManager, "test", Collections.emptyList(), Collections.emptyMap()),
taskMemoryManager, "test", Spillers.NOOP, Collections.emptyMap()),
new SimpleMemoryUsageRecorder());

manager = new CHNativeMemoryAllocatorManagerImpl(new CHNativeMemoryAllocator(-1L, listener));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,30 +50,31 @@ case class CachedColumnarBatch(
// spotless:off
/**
* Feature:
* 1. This serializer supports column pruning 2. TODO: support push down filter 3. Super TODO:
* support store offheap object directly
* 1. This serializer supports column pruning
* 2. TODO: support push down filter
* 3. Super TODO: support store offheap object directly
*
* The data transformation pipeline:
*
* - Serializer ColumnarBatch -> CachedColumnarBatch
* -> serialize to byte[]
* -> serialize to byte[]
*
* - Deserializer CachedColumnarBatch -> ColumnarBatch
* -> deserialize from byte[] to create Velox ColumnarBatch
* -> deserialize from byte[] to create Velox ColumnarBatch
*
* - Serializer InternalRow -> CachedColumnarBatch (support RowToColumnar)
* -> Convert InternalRow to ColumnarBatch
* -> Serializer ColumnarBatch -> CachedColumnarBatch
* -> Convert InternalRow to ColumnarBatch
* -> Serializer ColumnarBatch -> CachedColumnarBatch
*
* - Serializer InternalRow -> DefaultCachedBatch (RowToColumnar unsupported)
* -> Convert InternalRow to DefaultCachedBatch using vanilla Spark serializer
* -> Convert InternalRow to DefaultCachedBatch using vanilla Spark serializer
*
* - Deserializer CachedColumnarBatch -> InternalRow (support ColumnarToRow)
* -> Deserializer CachedColumnarBatch -> ColumnarBatch
* -> Convert ColumnarBatch to InternalRow
* -> Deserializer CachedColumnarBatch -> ColumnarBatch
* -> Convert ColumnarBatch to InternalRow
*
* - Deserializer DefaultCachedBatch -> InternalRow (ColumnarToRow unsupported)
* -> Convert DefaultCachedBatch to InternalRow using vanilla Spark serializer
* -> Convert DefaultCachedBatch to InternalRow using vanilla Spark serializer
*/
// spotless:on
class ColumnarCachedBatchSerializer extends CachedBatchSerializer with SQLConfHelper with Logging {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,19 @@
*/
package org.apache.spark.shuffle

import org.apache.celeborn.client.ShuffleClient
import org.apache.celeborn.common.CelebornConf
import org.apache.gluten.GlutenConfig
import org.apache.gluten.backendsapi.clickhouse.CHBackendSettings
import org.apache.gluten.memory.alloc.CHNativeMemoryAllocators
import org.apache.gluten.memory.memtarget.MemoryTarget
import org.apache.gluten.memory.memtarget.Spiller
import org.apache.gluten.memory.memtarget.Spillers
import org.apache.gluten.memory.memtarget.{MemoryTarget, Spiller, Spillers}
import org.apache.gluten.vectorized._

import org.apache.spark._
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.shuffle.celeborn.CelebornShuffleHandle
import org.apache.spark.sql.vectorized.ColumnarBatch

import org.apache.celeborn.client.ShuffleClient
import org.apache.celeborn.common.CelebornConf

import java.io.IOException
import java.util
import java.util.Locale

class CHCelebornHashBasedColumnarShuffleWriter[K, V](
Expand All @@ -43,13 +38,13 @@ class CHCelebornHashBasedColumnarShuffleWriter[K, V](
celebornConf: CelebornConf,
client: ShuffleClient,
writeMetrics: ShuffleWriteMetricsReporter)
extends CelebornHashBasedColumnarShuffleWriter[K, V](
shuffleId: Int,
handle,
context,
celebornConf,
client,
writeMetrics) {
extends CelebornHashBasedColumnarShuffleWriter[K, V](
shuffleId: Int,
handle,
context,
celebornConf,
client,
writeMetrics) {

private val customizedCompressCodec =
customizedCompressionCodec.toUpperCase(Locale.ROOT)
Expand Down Expand Up @@ -80,12 +75,14 @@ class CHCelebornHashBasedColumnarShuffleWriter[K, V](
GlutenConfig.getConf.chColumnarThrowIfMemoryExceed,
GlutenConfig.getConf.chColumnarFlushBlockBufferBeforeEvict,
GlutenConfig.getConf.chColumnarForceExternalSortShuffle,
GlutenConfig.getConf.chColumnarForceMemorySortShuffle
)
GlutenConfig.getConf.chColumnarForceMemorySortShuffle)
CHNativeMemoryAllocators.createSpillable(
"CelebornShuffleWriter",
new Spiller() {
override def spill(self: MemoryTarget, size: Long): Long = {
override def spill(self: MemoryTarget, phase: Spiller.Phase, size: Long): Long = {
if (!Spillers.PHASE_SET_SPILL_ONLY.contains(phase)) {
return 0L
}
if (nativeShuffleWriter == -1L) {
throw new IllegalStateException(
"Fatal: spill() called before a celeborn shuffle writer " +
Expand All @@ -98,10 +95,7 @@ class CHCelebornHashBasedColumnarShuffleWriter[K, V](
logInfo(s"Gluten shuffle writer: Spilled $spilled / $size bytes of data")
spilled
}

override def applicablePhases(): util.Set[Spiller.Phase] = Spillers.PHASE_SET_SPILL_ONLY
}
)
})
}
while (records.hasNext) {
val cb = records.next()._2.asInstanceOf[ColumnarBatch]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,26 +16,21 @@
*/
package org.apache.spark.shuffle

import org.apache.celeborn.client.ShuffleClient
import org.apache.celeborn.common.CelebornConf
import org.apache.gluten.GlutenConfig
import org.apache.gluten.columnarbatch.ColumnarBatches
import org.apache.gluten.memory.memtarget.MemoryTarget
import org.apache.gluten.memory.memtarget.Spiller
import org.apache.gluten.memory.memtarget.Spillers
import org.apache.gluten.memory.nmm.NativeMemoryManagers
import org.apache.gluten.exec.Runtimes
import org.apache.gluten.memory.memtarget.{MemoryTarget, Spiller, Spillers}
import org.apache.gluten.vectorized._

import org.apache.spark._
import org.apache.spark.memory.SparkMemoryUtil
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.shuffle.celeborn.CelebornShuffleHandle
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.SparkResourceUtil

import org.apache.celeborn.client.ShuffleClient
import org.apache.celeborn.common.CelebornConf

import java.io.IOException
import java.util

class VeloxCelebornHashBasedColumnarShuffleWriter[K, V](
shuffleId: Int,
Expand All @@ -44,15 +39,17 @@ class VeloxCelebornHashBasedColumnarShuffleWriter[K, V](
celebornConf: CelebornConf,
client: ShuffleClient,
writeMetrics: ShuffleWriteMetricsReporter)
extends CelebornHashBasedColumnarShuffleWriter[K, V](
shuffleId,
handle,
context,
celebornConf,
client,
writeMetrics) {
extends CelebornHashBasedColumnarShuffleWriter[K, V](
shuffleId,
handle,
context,
celebornConf,
client,
writeMetrics) {

private val runtime = Runtimes.contextInstance("CelebornShuffleWriter")

private val jniWrapper = ShuffleWriterJniWrapper.create()
private val jniWrapper = ShuffleWriterJniWrapper.create(runtime)

private var splitResult: SplitResult = _

Expand Down Expand Up @@ -105,38 +102,32 @@ class VeloxCelebornHashBasedColumnarShuffleWriter[K, V](
clientPushBufferMaxSize,
clientPushSortMemoryThreshold,
celebornPartitionPusher,
NativeMemoryManagers
.create(
"CelebornShuffleWriter",
new Spiller() {
override def spill(self: MemoryTarget, size: Long): Long = {
if (nativeShuffleWriter == -1L) {
throw new IllegalStateException(
"Fatal: spill() called before a celeborn shuffle writer " +
"is created. This behavior should be" +
"optimized by moving memory " +
"allocations from make() to split()")
}
logInfo(s"Gluten shuffle writer: Trying to push $size bytes of data")
// fixme pass true when being called by self
val pushed =
jniWrapper.nativeEvict(nativeShuffleWriter, size, false)
logInfo(s"Gluten shuffle writer: Pushed $pushed / $size bytes of data")
pushed
}

override def applicablePhases(): util.Set[Spiller.Phase] =
Spillers.PHASE_SET_SPILL_ONLY
}
)
.getNativeInstanceHandle,
handle,
context.taskAttemptId(),
GlutenShuffleUtils.getStartPartitionId(dep.nativePartitioning, context.partitionId),
"celeborn",
shuffleWriterType,
GlutenConfig.getConf.columnarShuffleReallocThreshold
)
GlutenConfig.getConf.columnarShuffleReallocThreshold)
runtime.addSpiller(new Spiller() {
override def spill(self: MemoryTarget, phase: Spiller.Phase, size: Long): Long = {
if (!Spillers.PHASE_SET_SPILL_ONLY.contains(phase)) {
return 0L
}
if (nativeShuffleWriter == -1L) {
throw new IllegalStateException(
"Fatal: spill() called before a celeborn shuffle writer " +
"is created. This behavior should be" +
"optimized by moving memory " +
"allocations from make() to split()")
}
logInfo(s"Gluten shuffle writer: Trying to push $size bytes of data")
// fixme pass true when being called by self
val pushed =
jniWrapper.nativeEvict(nativeShuffleWriter, size, false)
logInfo(s"Gluten shuffle writer: Pushed $pushed / $size bytes of data")
pushed
}
})
}
val startTime = System.nanoTime()
jniWrapper.write(nativeShuffleWriter, cb.numRows, handle, availableOffHeapPerTask())
Expand Down

0 comments on commit 609a976

Please sign in to comment.