[GLUTEN-7750][VL] Move ColumnarBuildSideRelation's memory occupation to Spark off-heap (#8127)
zjuwangg authored Jan 2, 2025
1 parent ff89539 commit dda601b
Showing 10 changed files with 612 additions and 76 deletions.
VeloxSparkPlanExecApi.scala
@@ -45,12 +45,14 @@ import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleEx
import org.apache.spark.sql.execution.joins.{BuildSideRelation, HashedRelationBroadcastMode}
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.python.ArrowEvalPythonExec
import org.apache.spark.sql.execution.unsafe.UnsafeColumnarBuildSideRelation
import org.apache.spark.sql.execution.utils.ExecUtil
import org.apache.spark.sql.expression.{UDFExpression, UserDefinedAggregateFunction}
import org.apache.spark.sql.hive.VeloxHiveUDFTransformer
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.task.TaskResources

import org.apache.commons.lang3.ClassUtils

@@ -621,6 +623,8 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
child: SparkPlan,
numOutputRows: SQLMetric,
dataSize: SQLMetric): BuildSideRelation = {
val useOffheapBroadcastBuildRelation =
GlutenConfig.getConf.enableBroadcastBuildRelationInOffheap
val serialized: Array[ColumnarBatchSerializeResult] = child
.executeColumnar()
.mapPartitions(itr => Iterator(BroadcastUtils.serializeStream(itr)))
@@ -633,7 +637,13 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
}
numOutputRows += serialized.map(_.getNumRows).sum
dataSize += rawSize
ColumnarBuildSideRelation(child.output, serialized.map(_.getSerialized), mode)
if (useOffheapBroadcastBuildRelation) {
TaskResources.runUnsafe {
new UnsafeColumnarBuildSideRelation(child.output, serialized.map(_.getSerialized), mode)
}
} else {
ColumnarBuildSideRelation(child.output, serialized.map(_.getSerialized), mode)
}
}

override def doCanonicalizeForBroadcastMode(mode: BroadcastMode): BroadcastMode = {
BroadcastUtils.scala
@@ -16,6 +16,7 @@
*/
package org.apache.spark.sql.execution

import org.apache.gluten.GlutenConfig
import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.columnarbatch.ColumnarBatches
import org.apache.gluten.runtime.Runtimes
@@ -27,7 +28,8 @@ import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, BroadcastPartitioning, IdentityBroadcastMode, Partitioning}
import org.apache.spark.sql.execution.joins.{HashedRelation, HashedRelationBroadcastMode, LongHashedRelation}
import org.apache.spark.sql.execution.joins.{BuildSideRelation, HashedRelation, HashedRelationBroadcastMode, LongHashedRelation}
import org.apache.spark.sql.execution.unsafe.UnsafeColumnarBuildSideRelation
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.task.TaskResources
@@ -45,7 +47,7 @@ object BroadcastUtils {
mode match {
case HashedRelationBroadcastMode(_, _) =>
// ColumnarBuildSideRelation to HashedRelation.
val fromBroadcast = from.asInstanceOf[Broadcast[ColumnarBuildSideRelation]]
val fromBroadcast = from.asInstanceOf[Broadcast[BuildSideRelation]]
val fromRelation = fromBroadcast.value.asReadOnlyCopy()
var rowCount: Long = 0
val toRelation = TaskResources.runUnsafe {
@@ -60,7 +62,7 @@
context.broadcast(toRelation).asInstanceOf[Broadcast[T]]
case IdentityBroadcastMode =>
// ColumnarBuildSideRelation to HashedRelation.
val fromBroadcast = from.asInstanceOf[Broadcast[ColumnarBuildSideRelation]]
val fromBroadcast = from.asInstanceOf[Broadcast[BuildSideRelation]]
val fromRelation = fromBroadcast.value.asReadOnlyCopy()
val toRelation = TaskResources.runUnsafe {
val rowIterator = fn(fromRelation.deserialized)
@@ -91,6 +93,7 @@
schema: StructType,
from: Broadcast[F],
fn: Iterator[InternalRow] => Iterator[ColumnarBatch]): Broadcast[T] = {
val useOffheapBuildRelation = GlutenConfig.getConf.enableBroadcastBuildRelationInOffheap
mode match {
case HashedRelationBroadcastMode(_, _) =>
// HashedRelation to ColumnarBuildSideRelation.
@@ -104,10 +107,17 @@
case result: ColumnarBatchSerializeResult =>
Array(result.getSerialized)
}
ColumnarBuildSideRelation(
SparkShimLoader.getSparkShims.attributesFromStruct(schema),
serialized,
mode)
if (useOffheapBuildRelation) {
new UnsafeColumnarBuildSideRelation(
SparkShimLoader.getSparkShims.attributesFromStruct(schema),
serialized,
mode)
} else {
ColumnarBuildSideRelation(
SparkShimLoader.getSparkShims.attributesFromStruct(schema),
serialized,
mode)
}
}
// Rebroadcast Velox relation.
context.broadcast(toRelation).asInstanceOf[Broadcast[T]]
@@ -123,10 +133,17 @@
case result: ColumnarBatchSerializeResult =>
Array(result.getSerialized)
}
ColumnarBuildSideRelation(
SparkShimLoader.getSparkShims.attributesFromStruct(schema),
serialized,
mode)
if (useOffheapBuildRelation) {
new UnsafeColumnarBuildSideRelation(
SparkShimLoader.getSparkShims.attributesFromStruct(schema),
serialized,
mode)
} else {
ColumnarBuildSideRelation(
SparkShimLoader.getSparkShims.attributesFromStruct(schema),
serialized,
mode)
}
}
// Rebroadcast Velox relation.
context.broadcast(toRelation).asInstanceOf[Broadcast[T]]
UnsafeBytesBufferArray.scala (new file)
@@ -0,0 +1,143 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.unsafe

import org.apache.spark.annotation.Experimental
import org.apache.spark.internal.Logging
import org.apache.spark.memory.{MemoryConsumer, MemoryMode, TaskMemoryManager}
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.array.LongArray

/**
 * Stores a broadcast variable's data in off-heap memory. The underlying data structure is a
 * LongArray allocated in off-heap memory.
 *
 * @param arraySize
 *   number of byte buffers, i.e. the length of the underlying Array[Array[Byte]]
 * @param bytesBufferLengths
 *   length of each individual byte buffer
 * @param totalBytes
 *   sum of all byte buffer lengths
 * @param tmm
 *   TaskMemoryManager that accounts for the off-heap allocation
 */
// scalastyle:off no.finalize
@Experimental
case class UnsafeBytesBufferArray(
arraySize: Int,
bytesBufferLengths: Array[Int],
totalBytes: Long,
tmm: TaskMemoryManager)
extends MemoryConsumer(tmm, MemoryMode.OFF_HEAP)
with Logging {

{
assert(
arraySize == bytesBufferLengths.length,
"Unsafe buffer array size " +
"not equal to buffer lengths!")
assert(totalBytes >= 0, "Unsafe buffer array total bytes can't be negative!")
}

/**
 * A single LongArray that stores the contents of all byte buffers. It is initialized once, the
 * first time it is accessed.
 */
private var longArray: LongArray = _

/** Byte offset of each bytesBuffer relative to the underlying LongArray's base offset. */
private val bytesBufferOffset = if (bytesBufferLengths.isEmpty) {
new Array(0)
} else {
bytesBufferLengths.init.scanLeft(0)(_ + _)
}
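// Offsets are the prefix sums of the buffer lengths: e.g. bytesBufferLengths = [3, 5, 2]
// yields bytesBufferOffset = [0, 3, 8], so buffer i starts at
// longArray.getBaseOffset + bytesBufferOffset(i).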

override def spill(l: Long, memoryConsumer: MemoryConsumer): Long = 0L

/**
* Put bytesBuffer at specified array index.
*
* @param index
* index of the array.
* @param bytesBuffer
* bytesBuffer to put.
*/
def putBytesBuffer(index: Int, bytesBuffer: Array[Byte]): Unit = this.synchronized {
assert(index < arraySize)
assert(bytesBuffer.length == bytesBufferLengths(index))
// Lazily allocate the underlying long array on the first put. LongArray is word-addressed,
// so totalBytes is rounded up to a whole number of 8-byte words.
if (null == longArray && index == 0) {
longArray = allocateArray((totalBytes + 7) / 8)
}

Platform.copyMemory(
bytesBuffer,
Platform.BYTE_ARRAY_OFFSET,
longArray.getBaseObject,
longArray.getBaseOffset + bytesBufferOffset(index),
bytesBufferLengths(index))
}

/**
 * Get the bytesBuffer at the specified index.
 *
 * @param index
 *   index of the array.
 * @return
 *   a copy of the bytes stored at that index, or an empty array if the underlying LongArray has
 *   not been allocated yet.
 */
def getBytesBuffer(index: Int): Array[Byte] = {
assert(index < arraySize)
if (null == longArray) {
return new Array[Byte](0)
}
val bytes = new Array[Byte](bytesBufferLengths(index))
Platform.copyMemory(
longArray.getBaseObject,
longArray.getBaseOffset + bytesBufferOffset(index),
bytes,
Platform.BYTE_ARRAY_OFFSET,
bytesBufferLengths(index))
bytes
}

/**
 * Get the memory address and length of the bytesBuffer at the specified index, typically used
 * when reading memory directly from off-heap. Because the backing LongArray is off-heap (null
 * base object), the returned offset is an absolute address.
 *
 * @param index
 *   index of the array.
 * @return
 *   (absolute off-heap address, length in bytes) of the buffer.
 */
def getBytesBufferOffsetAndLength(index: Int): (Long, Int) = {
assert(index < arraySize)
assert(longArray != null, "The broadcast data in offheap should not be null!")
val offset = longArray.getBaseOffset + bytesBufferOffset(index)
val length = bytesBufferLengths(index)
(offset, length)
}

/**
 * Frees the off-heap memory once the broadcast variable is garbage collected. For now, there is
 * no more elegant hook for releasing the underlying off-heap memory.
 */
override def finalize(): Unit = {
try {
if (longArray != null) {
freeArray(longArray)
longArray = null
}
} finally {
super.finalize()
}
}
}
// scalastyle:on no.finalize
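A minimal usage sketch of the new off-heap buffer array (hypothetical, distilled from the class above, not part of the diff). It assumes the calling code runs inside a Spark task and lives under an org.apache.spark package, since TaskContext.taskMemoryManager() is package-private to Spark:

import org.apache.spark.TaskContext
import org.apache.spark.sql.execution.unsafe.UnsafeBytesBufferArray

// Hypothetical example data: two serialized batches.
val buffers: Array[Array[Byte]] = Array(Array[Byte](1, 2, 3), Array[Byte](4, 5))
val lengths: Array[Int] = buffers.map(_.length)
val totalBytes: Long = lengths.map(_.toLong).sum

// Assumed source of the TaskMemoryManager: the running task's context.
val tmm = TaskContext.get().taskMemoryManager()
val unsafeArray = UnsafeBytesBufferArray(buffers.length, lengths, totalBytes, tmm)

// Copy each buffer into off-heap memory, then read one back.
buffers.zipWithIndex.foreach { case (bytes, i) => unsafeArray.putBytesBuffer(i, bytes) }
val restored: Array[Byte] = unsafeArray.getBytesBuffer(1) // Array(4, 5)

Note that buffers must be written starting from index 0, because the backing LongArray is only allocated on the first put.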