Skip to content

Commit

Permalink
feat: Remove COMET_EXEC_BROADCAST_ENABLED
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Mar 18, 2024
1 parent 93e81cc commit 86167fd
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 19 deletions.
9 changes: 5 additions & 4 deletions common/src/main/scala/org/apache/comet/CometConf.scala
Original file line number Diff line number Diff line change
Expand Up @@ -139,12 +139,13 @@ object CometConf {
.booleanConf
.createWithDefault(false)

val COMET_EXEC_BROADCAST_ENABLED: ConfigEntry[Boolean] =
val COMET_EXEC_BROADCAST_FORCE_ENABLED: ConfigEntry[Boolean] =
conf(s"$COMET_EXEC_CONFIG_PREFIX.broadcast.enabled")
.doc(
"Whether to enable broadcasting for Comet native operators. By default, " +
"this config is false. Note that this feature is not fully supported yet " +
"and only enabled for test purpose.")
"Whether to force enabling broadcasting for Comet native operators. By default, " +
"this config is false. Comet broadcast feature will be enabled automatically by " +
"Comet extension. But for unit tests, we need this feature to force enabling it " +
"for invalid cases. So this config is only used for unit test.")
.booleanConf
.createWithDefault(false)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ import org.apache.spark.internal.Logging
import org.apache.spark.network.util.ByteUnit
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.SparkSessionExtensions
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.comet._
import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle}
Expand All @@ -42,7 +41,7 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

import org.apache.comet.CometConf._
import org.apache.comet.CometSparkSessionExtensions.{isANSIEnabled, isCometBroadCastEnabled, isCometColumnarShuffleEnabled, isCometEnabled, isCometExecEnabled, isCometOperatorEnabled, isCometScan, isCometScanEnabled, isCometShuffleEnabled, isSchemaSupported}
import org.apache.comet.CometSparkSessionExtensions.{isANSIEnabled, isCometBroadCastForceEnabled, isCometColumnarShuffleEnabled, isCometEnabled, isCometExecEnabled, isCometOperatorEnabled, isCometScan, isCometScanEnabled, isCometShuffleEnabled, isSchemaSupported}
import org.apache.comet.parquet.{CometParquetScan, SupportsComet}
import org.apache.comet.serde.OperatorOuterClass.Operator
import org.apache.comet.serde.QueryPlanSerde
Expand Down Expand Up @@ -368,15 +367,26 @@ class CometSparkSessionExtensions
u
}

case b: BroadcastExchangeExec
if isCometNative(b.child) && isCometOperatorEnabled(conf, "broadcastExchangeExec") &&
isCometBroadCastEnabled(conf) =>
QueryPlanSerde.operator2Proto(b) match {
case Some(nativeOp) =>
val cometOp = CometBroadcastExchangeExec(b, b.child)
CometSinkPlaceHolder(nativeOp, b, cometOp)
case None => b
// `CometBroadcastExchangeExec`'s broadcast output is not compatible with Spark's broadcast
// exchange. It is only used for Comet native execution. We only transform Spark broadcast
// exchange to Comet broadcast exchange if its downstream is a Comet native plan or if the
// broadcast exchange is forced to be enabled by Comet config.
case plan
if (isCometNative(plan) || isCometBroadCastForceEnabled(conf)) &&
plan.children.exists(_.isInstanceOf[BroadcastExchangeExec]) =>
val newChildren = plan.children.map {
case b: BroadcastExchangeExec
if isCometNative(b.child) &&
isCometOperatorEnabled(conf, "broadcastExchangeExec") =>
QueryPlanSerde.operator2Proto(b) match {
case Some(nativeOp) =>
val cometOp = CometBroadcastExchangeExec(b, b.child)
CometSinkPlaceHolder(nativeOp, b, cometOp)
case None => b
}
case other => other
}
plan.withNewChildren(newChildren)

// Native shuffle for Comet operators
case s: ShuffleExchangeExec
Expand Down Expand Up @@ -547,11 +557,13 @@ object CometSparkSessionExtensions extends Logging {

private[comet] def isCometOperatorEnabled(conf: SQLConf, operator: String): Boolean = {
val operatorFlag = s"$COMET_EXEC_CONFIG_PREFIX.$operator.enabled"
conf.getConfString(operatorFlag, "false").toBoolean || isCometAllOperatorEnabled(conf)
val operatorDisabledFlag = s"$COMET_EXEC_CONFIG_PREFIX.$operator.disabled"
conf.getConfString(operatorFlag, "false").toBoolean || isCometAllOperatorEnabled(conf) &&
!conf.getConfString(operatorDisabledFlag, "false").toBoolean
}

private[comet] def isCometBroadCastEnabled(conf: SQLConf): Boolean = {
COMET_EXEC_BROADCAST_ENABLED.get(conf)
private[comet] def isCometBroadCastForceEnabled(conf: SQLConf): Boolean = {
COMET_EXEC_BROADCAST_FORCE_ENABLED.get(conf)
}

private[comet] def isCometShuffleEnabled(conf: SQLConf): Boolean =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ class CometExecSuite extends CometTestBase {

test("CometBroadcastExchangeExec") {
assume(isSpark34Plus, "ChunkedByteBuffer is not serializable before Spark 3.4+")
withSQLConf(CometConf.COMET_EXEC_BROADCAST_ENABLED.key -> "true") {
withSQLConf(CometConf.COMET_EXEC_BROADCAST_FORCE_ENABLED.key -> "true") {
withParquetTable((0 until 5).map(i => (i, i + 1)), "tbl_a") {
withParquetTable((0 until 5).map(i => (i, i + 1)), "tbl_b") {
val df = sql(
Expand All @@ -99,7 +99,7 @@ class CometExecSuite extends CometTestBase {

test("CometBroadcastExchangeExec: empty broadcast") {
assume(isSpark34Plus, "ChunkedByteBuffer is not serializable before Spark 3.4+")
withSQLConf(CometConf.COMET_EXEC_BROADCAST_ENABLED.key -> "true") {
withSQLConf(CometConf.COMET_EXEC_BROADCAST_FORCE_ENABLED.key -> "true") {
withParquetTable((0 until 5).map(i => (i, i + 1)), "tbl_a") {
withParquetTable((0 until 5).map(i => (i, i + 1)), "tbl_b") {
val df = sql(
Expand Down

0 comments on commit 86167fd

Please sign in to comment.