avoid copying memory between off-heap and on-heap
zjuwangg committed Dec 4, 2024
1 parent 26aa1e5 commit 6a0e31f
Showing 6 changed files with 77 additions and 59 deletions.
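
The gist of the change, before the per-file hunks: previously each serialized broadcast batch had to be copied from the off-heap store into an on-heap Array[Byte] just to cross the JNI boundary; now the batch's off-heap address and length are passed straight to the native deserializer. A minimal before/after sketch of the call pattern, using names that appear in the hunks below (a serializeHandle is assumed to have already been obtained from the serializer) — an illustration, not code from the commit:

import org.apache.gluten.vectorized.ColumnarBatchSerializerJniWrapper
import org.apache.spark.sql.execution.unsafe.UnsafeBytesBufferArray

// Before: getBytesBuffer materializes an on-heap byte[] copy of the batch,
// which is then handed to the byte[]-based JNI overload.
def deserializeCopying(
    jniWrapper: ColumnarBatchSerializerJniWrapper,
    serializeHandle: Long,
    batches: UnsafeBytesBufferArray,
    batchId: Int): Long =
  jniWrapper.deserialize(serializeHandle, batches.getBytesBuffer(batchId))

// After: the native deserializer reads directly from the off-heap address and
// length, skipping the on-heap detour entirely.
def deserializeZeroCopy(
    jniWrapper: ColumnarBatchSerializerJniWrapper,
    serializeHandle: Long,
    batches: UnsafeBytesBufferArray,
    batchId: Int): Long = {
  val (offset, length) = batches.getBytesBufferOffsetAndLength(batchId)
  jniWrapper.deserialize(serializeHandle, offset, length)
}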
@@ -24,7 +24,8 @@ import org.apache.gluten.expression._
import org.apache.gluten.expression.aggregate.{HLLAdapter, VeloxBloomFilterAggregate, VeloxCollectList, VeloxCollectSet}
import org.apache.gluten.extension.columnar.FallbackTags
import org.apache.gluten.sql.shims.SparkShimLoader
import org.apache.gluten.vectorized.{ColumnarBatchSerializeResult, ColumnarBatchSerializer}
import org.apache.gluten.vectorized.{ColumnarBatchSerializer, ColumnarBatchSerializeResult}

import org.apache.spark.{ShuffleDependency, SparkException}
import org.apache.spark.api.python.{ColumnarArrowEvalPythonExec, PullOutArrowEvalPythonPreProjectHelper}
import org.apache.spark.rdd.RDD
@@ -44,17 +45,18 @@ import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleEx
import org.apache.spark.sql.execution.joins.{BuildSideRelation, HashedRelationBroadcastMode}
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.python.ArrowEvalPythonExec
import org.apache.spark.sql.execution.unsafe.UnsafeColumnarBuildSideRelation
import org.apache.spark.sql.execution.utils.ExecUtil
import org.apache.spark.sql.expression.{UDFExpression, UserDefinedAggregateFunction}
import org.apache.spark.sql.hive.VeloxHiveUDFTransformer
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.commons.lang3.ClassUtils
import org.apache.spark.sql.execution.unsafe.UnsafeColumnarBuildSideRelation
import org.apache.spark.task.TaskResources
import org.apache.spark.util.TaskResources

import org.apache.commons.lang3.ClassUtils

import javax.ws.rs.core.UriBuilder

class VeloxSparkPlanExecApi extends SparkPlanExecApi {
@@ -621,8 +623,8 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
child: SparkPlan,
numOutputRows: SQLMetric,
dataSize: SQLMetric): BuildSideRelation = {
val useOffheapBroadcastBuildRelation = GlutenConfig.getConf
.enableBroadcastBuildRelationInOffheap
val useOffheapBroadcastBuildRelation =
GlutenConfig.getConf.enableBroadcastBuildRelationInOffheap
val serialized: Array[ColumnarBatchSerializeResult] = child
.executeColumnar()
.mapPartitions(itr => Iterator(BroadcastUtils.serializeStream(itr)))
@@ -635,7 +637,7 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
}
numOutputRows += serialized.map(_.getNumRows).sum
dataSize += rawSize
if (useOffheapBroadcastBuildRelation){
if (useOffheapBroadcastBuildRelation) {
TaskResources.runUnsafe {
new UnsafeColumnarBuildSideRelation(child.output, serialized.map(_.getSerialized))
}
@@ -21,6 +21,7 @@ import org.apache.gluten.columnarbatch.ColumnarBatches
import org.apache.gluten.runtime.Runtimes
import org.apache.gluten.sql.shims.SparkShimLoader
import org.apache.gluten.vectorized.{ColumnarBatchSerializeResult, ColumnarBatchSerializerJniWrapper}

import org.apache.spark.SparkContext
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.catalyst.InternalRow
@@ -134,9 +135,9 @@ object BroadcastUtils {
SparkShimLoader.getSparkShims.attributesFromStruct(schema),
serialized)
} else {
ColumnarBuildSideRelation(
  SparkShimLoader.getSparkShims.attributesFromStruct(schema),
  serialized)
}
}
// Rebroadcast Velox relation.
@@ -24,32 +24,31 @@ import org.apache.spark.unsafe.array.LongArray
import java.security.MessageDigest

/**
 * Used to store broadcast variable data in off-heap memory. The underlying data
 * structure is a LongArray allocated off-heap.
 *
 * @param arraySize
 *   the underlying array[array[byte]]'s length
 * @param bytesBufferLengths
 *   the length of each bytesBuffer in the underlying array[array[byte]]
 * @param totalBytes
 *   the sum of all bytesBuffers' lengths
 */
case class UnsafeBytesBufferArray(
    arraySize: Int,
    bytesBufferLengths: Array[Int],
    totalBytes: Long,
    tmm: TaskMemoryManager)
  extends MemoryConsumer(tmm, MemoryMode.OFF_HEAP)
  with Logging {

/**
 * A single array that stores all the bytesBuffers' values; it is initialized once,
 * on first access.
 */
private var longArray: LongArray = _

/** Offset of each bytesBuffer's start relative to the underlying LongArray's base. */
private val bytesBufferOffset = new Array[Int](arraySize)

{
@@ -137,17 +136,17 @@ case class UnsafeBytesBufferArray(
* It is needed when the broadcast variable is garbage collected, since for now we don't have an
* elegant way to free the underlying off-heap memory.
*/
override def finalize(): Unit = {
  try {
    if (longArray != null) {
      log.debug(s"BytesArrayInOffheap finalize $arraySize")
      freeArray(longArray)
      longArray = null
    }
  } finally {
    super.finalize()
  }
}

/** Used to debug input/output bytes. */
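
A side note on the layout this class implements: the serialized batches are packed into one contiguous off-heap allocation, with per-buffer offsets kept so that any buffer can later be handed to native code as an (address, length) pair. Below is a toy, self-contained sketch of that packing scheme using Spark's Platform utilities — an illustration of the idea only; the real class allocates through TaskMemoryManager/MemoryConsumer and tracks offsets in bytesBufferOffset:

import org.apache.spark.unsafe.Platform

// Toy illustration: pack N byte arrays into a single off-heap block and index
// them by offset, so each one can be addressed without copying it back on-heap.
final class OffheapPackedBuffers(buffers: Array[Array[Byte]]) {
  private val lengths: Array[Int] = buffers.map(_.length)
  private val offsets: Array[Long] = lengths.scanLeft(0L)(_ + _).init
  private val base: Long = Platform.allocateMemory(lengths.map(_.toLong).sum)

  for (i <- buffers.indices) {
    Platform.copyMemory(
      buffers(i), Platform.BYTE_ARRAY_OFFSET, null, base + offsets(i), lengths(i))
  }

  /** Absolute off-heap address and length of buffer i, ready for a native call. */
  def addressAndLength(i: Int): (Long, Int) = (base + offsets(i), lengths(i))

  def free(): Unit = Platform.freeMemory(base)
}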
@@ -16,16 +16,14 @@
*/
package org.apache.spark.sql.execution.unsafe

import com.esotericsoftware.kryo.io.{Input, Output}
import com.esotericsoftware.kryo.{Kryo, KryoSerializable}
import org.apache.arrow.c.ArrowSchema
import org.apache.gluten.columnarbatch.ColumnarBatches
import org.apache.gluten.iterator.Iterators
import org.apache.gluten.memory.arrow.alloc.ArrowBufferAllocators
import org.apache.gluten.runtime.Runtimes
import org.apache.gluten.sql.shims.SparkShimLoader
import org.apache.gluten.utils.ArrowAbiUtil
import org.apache.gluten.vectorized.{ColumnarBatchSerializerJniWrapper, NativeColumnarToRowJniWrapper}

import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.internal.Logging
import org.apache.spark.memory.{TaskMemoryManager, UnifiedMemoryManager}
@@ -38,24 +36,29 @@ import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.task.TaskResources
import org.apache.spark.util.Utils

import com.esotericsoftware.kryo.{Kryo, KryoSerializable}
import com.esotericsoftware.kryo.io.{Input, Output}
import org.apache.arrow.c.ArrowSchema

import java.io.{Externalizable, ObjectInput, ObjectOutput}

import scala.collection.JavaConverters.asScalaIteratorConverter

/**
 * UnsafeColumnarBuildSideRelation should be backed by off-heap memory to avoid on-heap OOM. It is
 * almost the same as ColumnarBuildSideRelation; ColumnarBuildSideRelation should be removed once
 * UnsafeColumnarBuildSideRelation matures.
 *
 * @param output
 * @param batches
 */
case class UnsafeColumnarBuildSideRelation(
    private var output: Seq[Attribute],
    private var batches: UnsafeBytesBufferArray)
  extends BuildSideRelation
  with Externalizable
  with Logging
  with KryoSerializable {

def this(output: Seq[Attribute], bytesBufferArray: Array[Array[Byte]]) {
// only used on the driver side when broadcasting the whole batches
@@ -113,11 +116,8 @@ case class UnsafeColumnarBuildSideRelation(
new UnifiedMemoryManager(SparkEnv.get.conf, Long.MaxValue, Long.MaxValue / 2, 1),
0)

batches =
  UnsafeBytesBufferArray(totalArraySize, bytesBufferLengths, totalBytes, taskMemoryManager)

for (i <- 0 until totalArraySize) {
val length = bytesBufferLengths(i)
@@ -138,11 +138,8 @@ case class UnsafeColumnarBuildSideRelation(
new UnifiedMemoryManager(SparkEnv.get.conf, Long.MaxValue, Long.MaxValue / 2, 1),
0)

batches =
  UnsafeBytesBufferArray(totalArraySize, bytesBufferLengths, totalBytes, taskMemoryManager)

for (i <- 0 until totalArraySize) {
val length = bytesBufferLengths(i)
@@ -153,7 +150,6 @@ case class UnsafeColumnarBuildSideRelation(
}
}


override def deserialized: Iterator[ColumnarBatch] = {
val runtime = Runtimes.contextInstance("UnsafeBuildSideRelation#deserialized")
val jniWrapper = ColumnarBatchSerializerJniWrapper.create(runtime)
@@ -179,10 +175,11 @@ case class UnsafeColumnarBuildSideRelation(
}

override def next: ColumnarBatch = {
val handle =
jniWrapper
.deserialize(serializeHandle, batches.getBytesBuffer(batchId))
val (offset, length) =
batches.getBytesBufferOffsetAndLength(batchId)
batchId += 1
val handle =
jniWrapper.deserialize(serializeHandle, offset, length)
ColumnarBatches.create(handle)
}
})
@@ -247,10 +244,10 @@ case class UnsafeColumnarBuildSideRelation(
}

override def next(): Iterator[InternalRow] = {
val batchBytes = batches.getBytesBuffer(batchId)
val (offset, length) = batches.getBytesBufferOffsetAndLength(batchId)
batchId += 1
val batchHandle =
serializerJniWrapper.deserialize(serializeHandle, batchBytes)
serializerJniWrapper.deserialize(serializeHandle, offset, length)
val batch = ColumnarBatches.create(batchHandle)
if (batch.numRows == 0) {
batch.close()
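For orientation, since the loop bodies above are collapsed: on the executor side, both readExternal and the Kryo read method rebuild the off-heap store by allocating an UnsafeBytesBufferArray against a fresh TaskMemoryManager and copying each serialized batch back in. A condensed sketch of that pattern follows, with readBuffer and putBytesBuffer as hypothetical stand-ins for the per-buffer reads and copies the collapsed loops perform; note that, like the diff's class, this code would have to live under an org.apache.spark package, since UnifiedMemoryManager is private[spark]:

import org.apache.spark.SparkEnv
import org.apache.spark.memory.{TaskMemoryManager, UnifiedMemoryManager}
import org.apache.spark.sql.execution.unsafe.UnsafeBytesBufferArray

// Condensed from the readExternal/read hunks above; putBytesBuffer is a
// hypothetical name for the copy step that the collapsed loops perform.
def rebuildOffheap(
    totalArraySize: Int,
    bytesBufferLengths: Array[Int],
    totalBytes: Long,
    readBuffer: Int => Array[Byte]): UnsafeBytesBufferArray = {
  // Effectively unbounded memory manager, copied from the diff above: broadcast
  // data is not charged against normal task memory accounting.
  val taskMemoryManager = new TaskMemoryManager(
    new UnifiedMemoryManager(SparkEnv.get.conf, Long.MaxValue, Long.MaxValue / 2, 1),
    0)
  val batches =
    UnsafeBytesBufferArray(totalArraySize, bytesBufferLengths, totalBytes, taskMemoryManager)
  for (i <- 0 until totalArraySize) {
    val bytes = readBuffer(i) // e.g. an Array[Byte] of size bytesBufferLengths(i)
    batches.putBytesBuffer(i, bytes) // hypothetical: copies on-heap bytes into the LongArray
  }
  batches
}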
16 changes: 16 additions & 0 deletions cpp/core/jni/JniWrapper.cc
@@ -1135,6 +1135,22 @@ JNIEXPORT jlong JNICALL Java_org_apache_gluten_vectorized_ColumnarBatchSerialize
JNI_METHOD_END(kInvalidObjectHandle)
}

JNIEXPORT jlong JNICALL Java_org_apache_gluten_vectorized_ColumnarBatchSerializerJniWrapper_deserialize( // NOLINT
JNIEnv* env,
jobject wrapper,
jlong serializerHandle,
jlong address,
jint size) {
JNI_METHOD_START
auto ctx = gluten::getRuntime(env, wrapper);

auto serializer = ctx->objectStore()->retrieve<ColumnarBatchSerializer>(serializerHandle);
GLUTEN_DCHECK(serializer != nullptr, "ColumnarBatchSerializer cannot be null");
auto batch = serializer->deserialize((uint8_t*) address, size);
return ctx->saveObject(batch);
JNI_METHOD_END(kInvalidObjectHandle)
}

JNIEXPORT void JNICALL Java_org_apache_gluten_vectorized_ColumnarBatchSerializerJniWrapper_close( // NOLINT
JNIEnv* env,
jobject wrapper,
@@ -42,5 +42,8 @@ public long rtHandle() {

public native long deserialize(long serializerHandle, byte[] data);

// Returns a native ColumnarBatch handle, deserializing directly from the given off-heap address and length
public native long deserialize(long serializerHandle, long offset, int len);

public native void close(long serializerHandle);
}
