[CORE] Port "SPARK-39983 Should not cache unserialized broadcast relations on the driver" (#5149)
ulysses-you authored Mar 28, 2024
1 parent 4456161 commit f528858
Showing 9 changed files with 160 additions and 38 deletions.
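
At a high level, the port routes the driver-side broadcast of the columnar relation through a version-aware shim, so that on Spark 3.4 and later the relation is broadcast in serialized-only form and a deserialized copy is not kept on the driver. Below is a condensed sketch of the pattern assembled from the hunks that follow (method bodies simplified for illustration; SparkContextUtils is the new wrapper added further down):

import scala.reflect.ClassTag

import org.apache.spark.{SparkContext, SparkContextUtils}
import org.apache.spark.broadcast.Broadcast

trait SparkShims {
  // Default used on Spark 3.3 and earlier: plain broadcast, which caches the value on the driver.
  def broadcastInternal[T: ClassTag](sc: SparkContext, value: T): Broadcast[T] =
    sc.broadcast(value)
}

class Spark34Shims extends SparkShims {
  // Spark 3.4+: broadcast in serialized-only form so the driver does not cache the relation.
  override def broadcastInternal[T: ClassTag](sc: SparkContext, value: T): Broadcast[T] =
    SparkContextUtils.broadcastInternal(sc, value)
}

// ColumnarBroadcastExchangeExec then broadcasts through the shim:
//   SparkShimLoader.getSparkShims.broadcastInternal(sparkContext, relation.asInstanceOf[Any])

The new GlutenLocalBroadcastExchangeSuite added below verifies the behavior in 'local-cluster' mode by checking TorrentBroadcast.hasCachedValue on the broadcast relation.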
@@ -80,7 +80,9 @@ case class ColumnarBroadcastExchangeExec(mode: BroadcastMode, child: SparkPlan)
val broadcasted = GlutenTimeMetric.millis(longMetric("broadcastTime")) {
_ =>
// Broadcast the relation
sparkContext.broadcast(relation.asInstanceOf[Any])
SparkShimLoader.getSparkShims.broadcastInternal(
sparkContext,
relation.asInstanceOf[Any])
}

// Update driver metrics
@@ -45,39 +45,7 @@ trait GlutenSQLTestsBaseTrait extends SharedSparkSession with GlutenTestsBaseTrait
}

override def sparkConf: SparkConf = {
// Native SQL configs
val conf = super.sparkConf
.setAppName("Gluten-UT")
.set("spark.driver.memory", "1G")
.set("spark.sql.adaptive.enabled", "true")
.set("spark.sql.shuffle.partitions", "1")
.set("spark.sql.files.maxPartitionBytes", "134217728")
.set("spark.memory.offHeap.enabled", "true")
.set("spark.memory.offHeap.size", "1024MB")
.set("spark.plugins", "io.glutenproject.GlutenPlugin")
.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.ColumnarShuffleManager")
.set("spark.sql.warehouse.dir", warehouse)
.set("spark.ui.enabled", "false")
.set("spark.gluten.ui.enabled", "false")
// Avoid static evaluation by Spark Catalyst. However, some UTs inherited from Spark
// then fail, e.g., a test expects a SparkException to be thrown but a wrapped
// exception is thrown instead.
// .set("spark.sql.optimizer.excludedRules", ConstantFolding.ruleName + "," +
// NullPropagation.ruleName)

if (BackendTestUtils.isCHBackendLoaded()) {
conf
.set("spark.io.compression.codec", "LZ4")
.set("spark.gluten.sql.columnar.backend.ch.worker.id", "1")
.set("spark.gluten.sql.enable.native.validation", "false")
.set(GlutenConfig.GLUTEN_LIB_PATH, SystemParameters.getClickHouseLibPath)
.set("spark.sql.files.openCostInBytes", "134217728")
.set("spark.unsafe.exceptionOnMemoryLeak", "true")
} else {
conf.set("spark.unsafe.exceptionOnMemoryLeak", "true")
}

conf
GlutenSQLTestsBaseTrait.nativeSparkConf(super.sparkConf, warehouse)
}

/**
@@ -126,3 +94,39 @@ trait GlutenSQLTestsBaseTrait extends SharedSparkSession with GlutenTestsBaseTrait
}
}
}

object GlutenSQLTestsBaseTrait {
def nativeSparkConf(origin: SparkConf, warehouse: String): SparkConf = {
// Native SQL configs
val conf = origin
.setAppName("Gluten-UT")
.set("spark.driver.memory", "1G")
.set("spark.sql.adaptive.enabled", "true")
.set("spark.sql.shuffle.partitions", "1")
.set("spark.sql.files.maxPartitionBytes", "134217728")
.set("spark.memory.offHeap.enabled", "true")
.set("spark.memory.offHeap.size", "1024MB")
.set("spark.plugins", "io.glutenproject.GlutenPlugin")
.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.ColumnarShuffleManager")
.set("spark.sql.warehouse.dir", warehouse)
.set("spark.ui.enabled", "false")
.set("spark.gluten.ui.enabled", "false")
// Avoid static evaluation by Spark Catalyst. However, some UTs inherited from Spark
// then fail, e.g., a test expects a SparkException to be thrown but a wrapped
// exception is thrown instead.
// .set("spark.sql.optimizer.excludedRules", ConstantFolding.ruleName + "," +
// NullPropagation.ruleName)

if (BackendTestUtils.isCHBackendLoaded()) {
conf
.set("spark.io.compression.codec", "LZ4")
.set("spark.gluten.sql.columnar.backend.ch.worker.id", "1")
.set("spark.gluten.sql.enable.native.validation", "false")
.set(GlutenConfig.GLUTEN_LIB_PATH, SystemParameters.getClickHouseLibPath)
.set("spark.sql.files.openCostInBytes", "134217728")
.set("spark.unsafe.exceptionOnMemoryLeak", "true")
} else {
conf.set("spark.unsafe.exceptionOnMemoryLeak", "true")
}
}
}
@@ -22,7 +22,7 @@ import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.{GlutenArithmeticExpressionSuite, GlutenBitwiseExpressionsSuite, GlutenCastSuite, GlutenCollectionExpressionsSuite, GlutenComplexTypeSuite, GlutenConditionalExpressionSuite, GlutenDateExpressionsSuite, GlutenDecimalExpressionSuite, GlutenHashExpressionsSuite, GlutenIntervalExpressionsSuite, GlutenLiteralExpressionSuite, GlutenMathExpressionsSuite, GlutenMiscExpressionsSuite, GlutenNondeterministicSuite, GlutenNullExpressionsSuite, GlutenPredicateSuite, GlutenRandomSuite, GlutenRegexpExpressionsSuite, GlutenSortOrderExpressionsSuite, GlutenStringExpressionsSuite}
import org.apache.spark.sql.connector.{GlutenDataSourceV2DataFrameSessionCatalogSuite, GlutenDataSourceV2DataFrameSuite, GlutenDataSourceV2FunctionSuite, GlutenDataSourceV2SQLSessionCatalogSuite, GlutenDataSourceV2SQLSuiteV1Filter, GlutenDataSourceV2SQLSuiteV2Filter, GlutenDataSourceV2Suite, GlutenDeleteFromTableSuite, GlutenFileDataSourceV2FallBackSuite, GlutenKeyGroupedPartitioningSuite, GlutenLocalScanSuite, GlutenMetadataColumnSuite, GlutenSupportsCatalogOptionsSuite, GlutenTableCapabilityCheckSuite, GlutenWriteDistributionAndOrderingSuite}
import org.apache.spark.sql.errors.{GlutenQueryCompilationErrorsDSv2Suite, GlutenQueryCompilationErrorsSuite, GlutenQueryExecutionErrorsSuite, GlutenQueryParsingErrorsSuite}
import org.apache.spark.sql.execution.{FallbackStrategiesSuite, GlutenBroadcastExchangeSuite, GlutenCoalesceShufflePartitionsSuite, GlutenExchangeSuite, GlutenReplaceHashWithSortAggSuite, GlutenReuseExchangeAndSubquerySuite, GlutenSameResultSuite, GlutenSortSuite, GlutenSQLAggregateFunctionSuite, GlutenSQLWindowFunctionSuite, GlutenTakeOrderedAndProjectSuite}
import org.apache.spark.sql.execution.{FallbackStrategiesSuite, GlutenBroadcastExchangeSuite, GlutenCoalesceShufflePartitionsSuite, GlutenExchangeSuite, GlutenLocalBroadcastExchangeSuite, GlutenReplaceHashWithSortAggSuite, GlutenReuseExchangeAndSubquerySuite, GlutenSameResultSuite, GlutenSortSuite, GlutenSQLAggregateFunctionSuite, GlutenSQLWindowFunctionSuite, GlutenTakeOrderedAndProjectSuite}
import org.apache.spark.sql.execution.adaptive.velox.VeloxAdaptiveQueryExecSuite
import org.apache.spark.sql.execution.datasources.{GlutenBucketingUtilsSuite, GlutenCSVReadSchemaSuite, GlutenDataSourceStrategySuite, GlutenDataSourceSuite, GlutenFileFormatWriterSuite, GlutenFileIndexSuite, GlutenFileMetadataStructSuite, GlutenFileSourceStrategySuite, GlutenHadoopFileLinesReaderSuite, GlutenHeaderCSVReadSchemaSuite, GlutenJsonReadSchemaSuite, GlutenMergedOrcReadSchemaSuite, GlutenMergedParquetReadSchemaSuite, GlutenOrcCodecSuite, GlutenOrcReadSchemaSuite, GlutenOrcV1AggregatePushDownSuite, GlutenOrcV2AggregatePushDownSuite, GlutenParquetCodecSuite, GlutenParquetReadSchemaSuite, GlutenParquetV1AggregatePushDownSuite, GlutenParquetV2AggregatePushDownSuite, GlutenPathFilterStrategySuite, GlutenPathFilterSuite, GlutenPruneFileSourcePartitionsSuite, GlutenV1WriteCommandSuite, GlutenVectorizedOrcReadSchemaSuite, GlutenVectorizedParquetReadSchemaSuite}
import org.apache.spark.sql.execution.datasources.binaryfile.GlutenBinaryFileFormatSuite
@@ -827,6 +827,7 @@ class VeloxTestSettings extends BackendTestSettings {
enableSuite[GlutenOuterJoinSuite]
enableSuite[FallbackStrategiesSuite]
enableSuite[GlutenBroadcastExchangeSuite]
enableSuite[GlutenLocalBroadcastExchangeSuite]
enableSuite[GlutenCoalesceShufflePartitionsSuite]
.excludeByPrefix("determining the number of reducers")
enableSuite[GlutenExchangeSuite]
@@ -16,6 +16,43 @@
*/
package org.apache.spark.sql.execution

import org.apache.spark.sql.GlutenSQLTestsBaseTrait
import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite}
import org.apache.spark.broadcast.TorrentBroadcast
import org.apache.spark.sql.{GlutenSQLTestsBaseTrait, GlutenTestsBaseTrait, SparkSession}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.functions.broadcast

class GlutenBroadcastExchangeSuite extends BroadcastExchangeSuite with GlutenSQLTestsBaseTrait {}

// Additional tests run in 'local-cluster' mode.
class GlutenLocalBroadcastExchangeSuite
extends SparkFunSuite
with LocalSparkContext
with GlutenTestsBaseTrait
with AdaptiveSparkPlanHelper {

def newSparkConf(): SparkConf = {
val conf = new SparkConf().setMaster("local-cluster[2,1,1024]")
GlutenSQLTestsBaseTrait.nativeSparkConf(conf, warehouse)
}

test("SPARK-39983 - Broadcasted relation is not cached on the driver") {
// Use a distributed cluster, as in local mode the broadcast value is actually cached.
val conf = newSparkConf()
sc = new SparkContext(conf)
val spark = new SparkSession(sc)

val df = spark.range(1).toDF()
val joinDF = df.join(broadcast(df), "id")
joinDF.collect()
val broadcastExchangeExec = collect(joinDF.queryExecution.executedPlan) {
case p: ColumnarBroadcastExchangeExec => p
}
assert(broadcastExchangeExec.size == 1, "one and only one ColumnarBroadcastExchangeExec")

// The broadcasted relation should not be cached on the driver.
val broadcasted =
broadcastExchangeExec(0).relationFuture.get().asInstanceOf[TorrentBroadcast[Any]]
assert(!broadcasted.hasCachedValue)
}
}
@@ -19,6 +19,7 @@ package io.glutenproject.sql.shims
import io.glutenproject.expression.Sig

import org.apache.spark.{SparkContext, TaskContext}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.io.FileCommitProtocol
import org.apache.spark.scheduler.TaskInfo
import org.apache.spark.shuffle.ShuffleHandle
@@ -46,6 +47,8 @@ import org.apache.hadoop.fs.{FileStatus, Path}

import java.util.{ArrayList => JArrayList, Map => JMap}

import scala.reflect.ClassTag

sealed abstract class ShimDescriptor

case class SparkShimDescriptor(major: Int, minor: Int, patch: Int) extends ShimDescriptor {
@@ -123,6 +126,13 @@ trait SparkShims {

def createTestTaskContext(): TaskContext

def broadcastInternal[T: ClassTag](sc: SparkContext, value: T): Broadcast[T] = {
// Since Spark 3.4, broadcast exchange can use `sc.broadcastInternal` instead of `sc.broadcast`
// so that the deserialized relation is not cached on the driver. See SPARK-39983 for details.
// TODO: remove this shim once support for Spark 3.3 and earlier is dropped.
sc.broadcast(value)
}

// To be compatible with Spark-3.5 and later
// See https://github.com/apache/spark/pull/41440
def setJobDescriptionOrTagForBroadcastExchange(
@@ -20,7 +20,8 @@ import io.glutenproject.GlutenConfig
import io.glutenproject.expression.{ExpressionNames, Sig}
import io.glutenproject.sql.shims.{ShimDescriptor, SparkShims}

import org.apache.spark.{ShuffleUtils, SparkContext, SparkException, TaskContext, TaskContextUtils}
import org.apache.spark.{ShuffleUtils, SparkContext, SparkContextUtils, SparkException, TaskContext, TaskContextUtils}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.io.FileCommitProtocol
import org.apache.spark.paths.SparkPath
import org.apache.spark.scheduler.TaskInfo
@@ -53,6 +54,8 @@ import org.apache.hadoop.fs.{FileStatus, Path}
import java.time.ZoneOffset
import java.util.{HashMap => JHashMap, Map => JMap}

import scala.reflect.ClassTag

class Spark34Shims extends SparkShims {
override def getShimDescriptor: ShimDescriptor = SparkShimProvider.DESCRIPTOR

@@ -241,6 +244,10 @@ class Spark34Shims extends SparkShims {
TaskContextUtils.createTestTaskContext()
}

override def broadcastInternal[T: ClassTag](sc: SparkContext, value: T): Broadcast[T] = {
SparkContextUtils.broadcastInternal(sc, value)
}

def setJobDescriptionOrTagForBroadcastExchange(
sc: SparkContext,
broadcastExchange: BroadcastExchangeLike): Unit = {
@@ -0,0 +1,27 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark

import org.apache.spark.broadcast.Broadcast

import scala.reflect.ClassTag

object SparkContextUtils {
def broadcastInternal[T: ClassTag](sc: SparkContext, value: T): Broadcast[T] = {
sc.broadcastInternal(value, serializedOnly = true)
}
}
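
This wrapper is declared in the org.apache.spark package, presumably because SparkContext.broadcastInternal is package-private to Spark and therefore not directly callable from Gluten's own packages. A minimal usage sketch from outside that package (the enclosing object and package below are illustrative; only SparkContextUtils.broadcastInternal comes from this commit):

package io.glutenproject.example // illustrative package, outside org.apache.spark

import scala.reflect.ClassTag

import org.apache.spark.{SparkContext, SparkContextUtils}
import org.apache.spark.broadcast.Broadcast

object SerializedOnlyBroadcast {
  // Delegates to sc.broadcastInternal(value, serializedOnly = true) through the wrapper,
  // so the driver does not retain a deserialized copy of `value`.
  def apply[T: ClassTag](sc: SparkContext, value: T): Broadcast[T] =
    SparkContextUtils.broadcastInternal(sc, value)
}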
@@ -20,7 +20,8 @@ import io.glutenproject.GlutenConfig
import io.glutenproject.expression.{ExpressionNames, Sig}
import io.glutenproject.sql.shims.{ShimDescriptor, SparkShims}

import org.apache.spark.{ShuffleUtils, SparkContext, SparkException, TaskContext, TaskContextUtils}
import org.apache.spark.{ShuffleUtils, SparkContext, SparkContextUtils, SparkException, TaskContext, TaskContextUtils}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.io.FileCommitProtocol
import org.apache.spark.paths.SparkPath
import org.apache.spark.scheduler.TaskInfo
@@ -52,6 +53,8 @@ import org.apache.hadoop.fs.{FileStatus, Path}
import java.time.ZoneOffset
import java.util.{HashMap => JHashMap, Map => JMap}

import scala.reflect.ClassTag

class Spark35Shims extends SparkShims {
override def getShimDescriptor: ShimDescriptor = SparkShimProvider.DESCRIPTOR

@@ -238,6 +241,10 @@ class Spark35Shims extends SparkShims {
TaskContextUtils.createTestTaskContext()
}

override def broadcastInternal[T: ClassTag](sc: SparkContext, value: T): Broadcast[T] = {
SparkContextUtils.broadcastInternal(sc, value)
}

override def setJobDescriptionOrTagForBroadcastExchange(
sc: SparkContext,
broadcastExchange: BroadcastExchangeLike): Unit = {
@@ -0,0 +1,27 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark

import org.apache.spark.broadcast.Broadcast

import scala.reflect.ClassTag

object SparkContextUtils {
def broadcastInternal[T: ClassTag](sc: SparkContext, value: T): Broadcast[T] = {
sc.broadcastInternal(value, serializedOnly = true)
}
}
